In [None]:
import numpy as np
import pandas as pd
import os
import shutil

import torchvision
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler, TensorDataset
from torch.utils import data
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models
from torchvision.utils import save_image
from tqdm import tqdm, tqdm_notebook
from sklearn.metrics import f1_score
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import time
import copy
import random
from itertools import combinations

#from easyfsl.data_tools  import TaskSampler

# Run on the GPU when one is available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(os.listdir("../input"))

# Unpack the competition archive into the Kaggle working directory.
import zipfile
with zipfile.ZipFile('../input/platesv2/plates.zip', 'r') as archive:
    archive.extractall('/kaggle/working/')

print('After zip extraction:')
print(os.listdir("/kaggle/working/"))

# Utils

In [None]:
# Pretrained Mask R-CNN, used by circle_crop as a fallback when no plate circle is found.
detector = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
detector = detector.to(DEVICE)
detector.eval()

# numpy HWC image -> PIL image -> float CHW tensor in [0, 1] (the detector's input format).
detect_transforms = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

def _fill_outside_mask(img, mask, fill):
    """Keep `img` where `mask` is 1 and repaint the rest of the image with `fill`.

    Only background pixels whose every channel is > 0 are repainted; pixels
    with at least one zero channel keep the value 0.
    """
    res = img * mask
    background = img - res
    background[np.where((background > [0, 0, 0]).all(axis=2))] = fill
    return background + res


def _hough_circle_crop(img, fill):
    """Crop `img` to the bounding square of its strongest Hough circle.

    Raises:
        ValueError: if no circle is found or the square does not fit in the image.
    """
    # NOTE(review): cv2.imread returns BGR, but RGB2GRAY weights are used here;
    # kept as-is because changing it would alter which circles are detected.
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1.5, gray.shape[0] / 4,
                               param1=300, param2=20, minRadius=70)
    if circles is None:
        raise ValueError('no circle found')
    x, y, r = circles[0, 0].astype('int32')
    half_a = int(r)  # half side of the crop square (= radius: circumscribed square)
    if half_a > y or half_a + y > img.shape[0] or half_a > x or half_a + x > img.shape[1]:
        raise ValueError('circle does not fit inside the image')

    mask = cv2.circle(np.zeros_like(img), (x, y), r, (1, 1, 1), -1)
    res = _fill_outside_mask(img, mask, fill)
    return res[y - half_a: y + half_a, x - half_a: x + half_a, :]


def circle_crop(img, min_area=5000):
    """Crop the plate out of an image and paint the background with the mean colour.

    Strategy 0: Hough circle detection (see `_hough_circle_crop`).
    Strategy 1 (fallback when strategy 0 fails): Mask R-CNN; the first detection
    whose box area exceeds `min_area` is used as mask and crop box.

    Args:
        img: HxWx3 uint8 image as returned by ``cv2.imread``.
        min_area: minimum box area (pixels^2) for a Mask R-CNN detection to be used.

    Returns:
        ``(image, code)``: code 0 = circle crop, 1 = Mask R-CNN crop,
        2 = nothing usable found (image returned unchanged).
    """
    # `mean` is the normalisation mean defined in the config cell.
    fill = mean * 255
    try:
        return _hough_circle_crop(img, fill), 0
    except Exception:
        # No usable circle: fall through to the detector.
        pass

    tensor = detect_transforms(img).to(DEVICE)
    with torch.no_grad():  # inference only; do not build an autograd graph
        pred = detector(tensor.unsqueeze(0))[0]
    boxes, masks = pred['boxes'], pred['masks']
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    keep = areas > min_area
    boxes, masks = boxes[keep], masks[keep]
    # Check emptiness AFTER filtering: all detections may be too small.
    if len(boxes) == 0:
        return img, 2

    # Boxes are [x1, y1, x2, y2] in the same (top-left origin) frame as `img`
    # and as the mask, so the rows are cropped directly without flipping.
    x1, y1, x2, y2 = boxes[0].cpu().numpy().astype('int')
    mask = masks[0].permute(1, 2, 0).cpu().numpy().round()
    res = _fill_outside_mask(img, mask, fill)
    return res[y1:y2, x1:x2, :], 1
        
def crop_and_save(input_dir, output_dir, error_list):
    """Apply `circle_crop` to every image in `input_dir` and write the results to `output_dir`.

    Args:
        input_dir: directory of source images ('.DS_Store' is skipped).
        output_dir: directory to write cropped images to (may equal `input_dir`).
        error_list: list that receives the names of files that could not be read
            (these are not written) and of files where no plate was detected
            (these are written unchanged).
    """
    for name in tqdm(sorted(os.listdir(input_dir))):
        if name == '.DS_Store':
            continue
        img = cv2.imread(os.path.join(input_dir, name))
        if img is None:
            # cv2.imread returns None for unreadable / non-image files.
            error_list.append(name)
            continue
        cropped, code = circle_crop(img)
        if code == 2:
            error_list.append(name)
        cv2.imwrite(os.path.join(output_dir, name), cropped)
            
def image_fromtensor(input_tensor, norm_std=None, norm_mean=None):
    """Undo per-channel normalisation and return an HWC numpy image for plotting.

    Args:
        input_tensor: CHW tensor normalised as ``(x - mean) / std``.
        norm_std: per-channel std; defaults to the notebook-level ``std``.
        norm_mean: per-channel mean; defaults to the notebook-level ``mean``.

    Returns:
        HWC numpy array ``image * std + mean``.
    """
    if norm_std is None:
        norm_std = std
    if norm_mean is None:
        norm_mean = mean
    image = input_tensor.detach().permute(1, 2, 0).cpu().numpy()
    return image * norm_std + norm_mean
def show_input(input_tensor, title=''):
    """Render one normalised image tensor with matplotlib and display it.

    Args:
        input_tensor: CHW tensor normalised with the notebook-level mean/std.
        title: optional figure title.
    """
    rgb_image = image_fromtensor(input_tensor)
    plt.imshow(rgb_image)
    plt.title(title)
    # Show, then pause briefly so interactive backends get to process the draw.
    plt.show()
    plt.pause(0.001)


In [None]:
def calc_sims(query_dataloader, support_dataloader, models=(), dist='l2'):
    """Pairwise distances between query and support embeddings, one matrix per model.

    Each loader is iterated exactly once. Every model embeds every batch and the
    embeddings are L2-normalised before distances are computed.

    Args:
        query_dataloader: yields ``(images, labels)`` query batches.
        support_dataloader: yields ``(images, labels)`` support batches.
        models: sequence of embedding networks mapping a batch to ``Batch x EmbSize``.
        dist: ``'l2'`` for Euclidean distance; anything else gives cosine distance ``1 - cos``.

    Returns:
        ``(sims, support_labels)``: ``sims`` has shape
        ``N_models x len(query) x len(support)`` (on ``DEVICE``); ``support_labels``
        holds exactly one label per support sample, in loader order.

    Raises:
        ValueError: if ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError('calc_sims needs at least one model')

    def _embed(loader, show_progress):
        # One pass over `loader`: a normalised N x EmbSize tensor per model, plus the labels.
        chunks = [[] for _ in models]
        labels = []
        for X, y in (tqdm(loader) if show_progress else loader):
            X = X.to(DEVICE)
            labels.append(y.data)
            for k, m in enumerate(models):
                z = m(X)
                chunks[k].append(z / torch.linalg.norm(z, dim=1, keepdim=True))
        return [torch.cat(c) for c in chunks], labels

    with torch.no_grad():
        q_embs, _ = _embed(query_dataloader, show_progress=True)
        s_embs, support_labels = _embed(support_dataloader, show_progress=False)
        per_model = []
        for q_z, s_z in zip(q_embs, s_embs):
            if dist == 'l2':
                # Exact pairwise differences rather than the matmul shortcut, which
                # adds rounding error on large matrices.
                d = torch.cdist(q_z, s_z, compute_mode='donot_use_mm_for_euclid_dist')
            else:
                d = 1 - torch.mm(q_z, torch.t(s_z))
            per_model.append(d)
        sims = torch.stack(per_model)
    return sims, torch.cat(support_labels)

def show_nearest(dataset_q, dataset_s, ind,  knn=5, start_index =0 , num_imgs = 10):
    """Plot query images next to their `knn` nearest support images.

    One figure per query, for queries ``start_index .. start_index + num_imgs - 1``.
    Row ``i`` of `ind` is expected to hold support indices ordered nearest-first
    (e.g. an argsort of distances); each neighbour is titled with its support label.
    """
    for query_idx in range(start_index, start_index + num_imgs):
        query_img, _ = dataset_q[query_idx]
        neighbours = ind[query_idx, :knn]
        _, axes = plt.subplots(1, knn + 1, figsize=(20, 12))
        axes[0].imshow(image_fromtensor(query_img))
        for slot in range(1, len(axes)):
            nb = neighbours[slot - 1]
            axes[slot].imshow(image_fromtensor(dataset_s[nb][0]))
            axes[slot].set_title(dataset_s.targets[nb])

# Data

In [None]:
from typing import List, Tuple

# Mix-up tables: for a pair of (sorted) class ids, the blend weights applied to
# (lower-label sample, higher-label sample) and the label given to the blend.
# Pairs not listed (e.g. same-class pairs) are averaged 50/50 and keep the lower label.
combinations_of_classes = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] # list(combinations(range(4), 2))
weights_of_classes = [(0.3, 0.7), (0.5, 0.5), (0.3, 0.7), (0.7, 0.3), (0.5, 0.5), (0.3, 0.7)]
merged_classes = [ 1, 2, 3, 3, 3, 3]

class PlatesDataset(Dataset):
    """Dataset of ``(image, label)`` pairs with optional pairwise mix-up augmentation.

    Args:
        data: indexable/iterable of ``(image_tensor, label)`` pairs (e.g. a list, or a
            torchvision ImageFolder whose transform already yields tensors).
        transform: optional callable applied to the image on every ``__getitem__``.
        mix_up: if True, every unordered pair of samples is blended using the
            module-level mix-up tables. The blends are stored first, followed by the
            original samples. The source is read once and kept in memory.
    """

    def __init__(self, data: List[Tuple[torch.Tensor, int]], transform=None, mix_up = False):
        super().__init__()
        if mix_up:
            # Materialise once so every source image is loaded a single time.
            samples = list(data)
            mixed = [self._mix(a, b) for a, b in combinations(samples, 2)]
            data = mixed + samples

        self.data = data
        self.transform = transform
        # Prefer a ready-made label list (ImageFolder exposes `.targets`) so building
        # the targets does not load every image from disk.
        targets = getattr(data, 'targets', None)
        self.targets = list(targets) if targets is not None else [y for _, y in data]

    @staticmethod
    def _mix(first, second):
        """Blend two ``(image, label)`` samples according to the mix-up tables."""
        low, high = sorted((first, second), key=lambda item: item[1])
        pair = (low[1], high[1])
        if pair in combinations_of_classes:
            k = combinations_of_classes.index(pair)
            w_low, w_high = weights_of_classes[k]
            label = merged_classes[k]
        else:
            w_low, w_high = 0.5, 0.5
            label = low[1]
        return low[0] * w_low + high[0] * w_high, label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Single lookup: indexing an ImageFolder loads (and transforms) the image.
        x, y = self.data[index]
        if self.transform:
            x = self.transform(x)
        return x, y
        


class TaskSampler(Sampler):
    """
    Samples batches in the shape of few-shot classification tasks. At each iteration it selects
    n_way classes, then samples n_shot support and n_query query items from each of them.

    When n_way equals the number of classes in the dataset, every task uses all classes in
    sorted label order (no randomness is spent on class selection). Otherwise n_way classes
    are drawn at random for each task.
    """

    def __init__(
        self, dataset: Dataset, n_way: int, n_shot: int, n_query: int, n_tasks: int
    ):
        """
        Args:
            dataset: dataset from which to sample classification tasks. Must have a field
                'targets': a list of length len(dataset) containing the labels of all items.
            n_way: number of classes in one task
            n_shot: number of support images for each class in one task
            n_query: number of query images for each class in one task
            n_tasks: number of tasks to sample
        Raises:
            ValueError: if n_way is larger than the number of distinct labels in the dataset.
        """
        super().__init__(data_source=None)
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_tasks = n_tasks

        assert hasattr(
            dataset, "targets"
        ), "TaskSampler needs a dataset with a field 'targets' containing the labels of all images."
        # label -> dataset indices carrying that label, in dataset order
        self.items_per_label = {}
        for item, label in enumerate(dataset.targets):
            self.items_per_label.setdefault(label, []).append(item)

        if n_way > len(self.items_per_label):
            raise ValueError(
                f"n_way={n_way} but the dataset only has {len(self.items_per_label)} classes."
            )

    def __len__(self):
        return self.n_tasks

    def __iter__(self):
        """Yield one flat index tensor per task, grouped by class.

        Each class contributes n_shot + n_query consecutive indices. random.sample raises
        ValueError if a class has fewer items than that.
        """
        all_labels = sorted(self.items_per_label)
        per_class = self.n_shot + self.n_query
        for _ in range(self.n_tasks):
            if self.n_way == len(all_labels):
                task_labels = all_labels
            else:
                task_labels = random.sample(all_labels, self.n_way)
            # pylint: disable=not-callable
            yield torch.cat(
                [
                    torch.tensor(random.sample(self.items_per_label[label], per_class))
                    for label in task_labels
                ]
            )
            # pylint: enable=not-callable

    def episodic_collate_fn(
        self, input_data: List[Tuple[torch.Tensor, int]]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Collate function to be used as argument for the collate_fn parameter of episodic
            data loaders. Expects items in the order produced by ``__iter__``: grouped by
            class, n_shot + n_query consecutive items per class.
        Args:
            input_data: each element is a tuple containing:
                - an image as a torch Tensor
                - the label of this image
        Returns:
            ((support images, support labels, query images), query labels). Labels are
            re-indexed to 0..n_way-1 following the sorted order of the labels present in
            the task. The nesting means a loader using this collate yields (X, y) pairs.
        """
        # Sorted so the label remapping is deterministic for any orderable label type.
        true_class_ids = sorted({x[1] for x in input_data})

        all_images = torch.cat([x[0].unsqueeze(0) for x in input_data])
        all_images = all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        )
        # pylint: disable=not-callable
        all_labels = torch.tensor(
            [true_class_ids.index(x[1]) for x in input_data]
        ).reshape((self.n_way, self.n_shot + self.n_query))
        # pylint: enable=not-callable

        support_images = all_images[:, : self.n_shot].reshape(
            (-1, *all_images.shape[2:])
        )
        query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))
        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()

        return (support_images, support_labels, query_images), query_labels

# Models

In [None]:
class Net(nn.Module):
    """Base network: ImageNet-pretrained ResNet-50 backbone with optional BatchNorm freezing.

    Args:
        freeze_bn: if True, every BatchNorm2d in the backbone stays in eval mode (running
            statistics are not updated) even when the model is put in training mode.
        freeze_bn_affine: if True (and freeze_bn), BatchNorm weight/bias stop receiving gradients.
        fc_backbone: module that replaces the backbone's final ``fc`` layer;
            ``False``/``None`` keeps ResNet's original classifier head.

    Subclasses implement ``forward``.
    """

    def __init__(self, freeze_bn, freeze_bn_affine, fc_backbone):
        super().__init__()

        self.freeze_bn = freeze_bn
        self.freeze_bn_affine = freeze_bn_affine

        self.ce_loss = nn.CrossEntropyLoss()

        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
        # `weights=...`; kept for compatibility with older torchvision versions.
        self.backbone = models.resnet50(pretrained = True)
        # Compare against the sentinels explicitly: module truthiness is unreliable
        # (an empty nn.Sequential is falsy).
        if fc_backbone is not None and fc_backbone is not False:
            self.backbone.fc = fc_backbone

    def train(self, mode=True):
        """Set train/eval mode like ``nn.Module.train``, keeping BatchNorm frozen if requested.

        Returns:
            self, as ``nn.Module.train`` does (so ``model.train()`` can be chained).
        """
        super().train(mode)
        if self.freeze_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.freeze_bn_affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
        return self

    def forward(self, *args, **kwargs):
        """Implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")

    def loss_function(self, outputs, labels):
        """Cross-entropy between ``outputs`` (logits/scores) and integer ``labels``."""
        return self.ce_loss(outputs, labels)


class Proto_Net(Net):
    """Prototypical network: scores queries against per-class mean support embeddings.

    Args:
        freeze_bn, freeze_bn_affine, fc_backbone: forwarded to ``Net``.
        dist: ``'l2'`` (scores are negative Euclidean distances) or ``'cos'`` (scores are
            dot products between the normalised query and the prototypes).

    Raises:
        ValueError: if ``dist`` is neither ``'l2'`` nor ``'cos'``.
    """

    def __init__(self, freeze_bn = False, freeze_bn_affine = False, fc_backbone = False, dist='l2',  ):
        # Validate before loading the backbone so a typo fails fast.
        if dist not in ('l2', 'cos'):
            raise ValueError(f"dist must be 'l2' or 'cos', got {dist!r}")
        super().__init__(freeze_bn = freeze_bn, freeze_bn_affine = freeze_bn_affine, fc_backbone = fc_backbone)
        self.dist = dist

    def forward(self, x ):
        """Score every query against every class prototype.

        Args:
            x: ``(support, support_labels, query)``. Support labels must be 0..n_way-1.

        Returns:
            ``n_query x n_way`` scores (higher = closer), usable as logits.
            The prototypes are also stored in ``self.z_proto``.
        """
        support, support_labels, query = x
        z_support = self.backbone(support)
        z_query = self.backbone(query)
        z_support = z_support / torch.linalg.norm(z_support, dim=1, keepdim=True)
        z_query = z_query / torch.linalg.norm(z_query, dim=1, keepdim=True)

        # Infer the number of classes from the labels of the support set.
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all support embeddings with label i.
        self.z_proto = torch.stack(
            [z_support[support_labels == label].mean(0) for label in range(n_way)]
        )

        if self.dist == 'l2':
            return -torch.cdist(z_query, self.z_proto)
        if self.dist == 'cos':
            # NOTE: prototypes (means of unit vectors) are not re-normalised, so this is a
            # dot product rather than a strict cosine similarity.
            return torch.mm(z_query, torch.t(self.z_proto))
        raise ValueError(f"dist must be 'l2' or 'cos', got {self.dist!r}")
        
            
            
class Clf_Net(Net):
    """Plain classifier: backbone features -> Dropout(0.2) -> Linear(in_features, num_classes).

    Args:
        freeze_bn, freeze_bn_affine: forwarded to ``Net``.
        fc_backbone: replacement for the backbone's ``fc``. When ``False``/``None``, it is
            replaced by ``nn.Identity`` so the backbone emits raw features.
        num_classes: number of output classes.
        in_features: width of the backbone's feature vector (2048 for ResNet-50).
    """

    def __init__(self, freeze_bn = False, freeze_bn_affine = False, fc_backbone = False, num_classes=2, in_features=2048):
        if fc_backbone is None or fc_backbone is False:
            # Without this, ResNet's 1000-way ImageNet head would feed a
            # Linear(in_features, ...) layer and fail with a shape mismatch.
            fc_backbone = nn.Identity()
        super().__init__(freeze_bn = freeze_bn, freeze_bn_affine = freeze_bn_affine, fc_backbone = fc_backbone)

        self.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(in_features, num_classes))

    def forward(self, X):
        """Return class logits of shape ``Batch x num_classes``."""
        return self.fc(self.backbone(X))

# Configs

In [None]:
# Reproducibility: seed every RNG in use.
SEED = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)
torch.backends.cudnn.deterministic = True

data_root = '/kaggle/working/plates/'

# Optimisation / input size
batch_size, learning_rate, num_epoch, input_size = 18, 0.0001, 30, 224

# Few-shot task shape
num_classes = 2 # 2 or 4
N_SHOT = N_QUERY = N_TASKS = 5

# Data options
mix_up, remove_strange_plates = True, True
l2_reg = 0.001

# ImageNet normalisation statistics (RGB order)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

class_names = ['cleaned', 'dirty']

if num_classes == 4:
    # Patterned plates get their own 'patern_*' classes.
    new_class_names = {name: 'patern_' + name for name in class_names}
    patern_ids = {'cleaned': ['0006.jpg', '0007.jpg', '0008.jpg', '0009.jpg', '0011.jpg', '0013.jpg', '0016.jpg'],
                  'dirty': ['0000.jpg', '0001.jpg', '0006.jpg', '0010.jpg', '0011.jpg', '0015.jpg', '0019.jpg']}
    

In [None]:
# Crop the training images in place, class by class.
erorrs = []
for split_name in ('cleaned', 'dirty'):
    path = os.path.join(data_root, 'train/' + split_name)
    crop_and_save(path, path, erorrs)

erorrs

In [None]:
# Remove atypical plates and duplicate images from the training set.
# Pure Python instead of `!rm`: re-running the cell is a no-op rather than an error.
STRANGE_PLATES = {
    'cleaned': ['0000.jpg', '0004.jpg', '0006.jpg'],
    'dirty': ['0001.jpg', '0003.jpg', '0004.jpg', '0010.jpg', '0012.jpg'],
}
if remove_strange_plates:
    for cls_name, file_names in STRANGE_PLATES.items():
        for file_name in file_names:
            file_path = os.path.join(data_root, 'train', cls_name, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)

In [None]:
# Divide the data into 4 classes: patterned plates move into their own 'patern_*' folders.
if num_classes == 4:
    train_root = os.path.join(data_root, 'train')
    for cl_name in class_names:
        src_dir = os.path.join(train_root, cl_name)
        dst_dir = os.path.join(train_root, new_class_names[cl_name])
        if not os.path.exists(dst_dir):
            os.mkdir(dst_dir)
        to_move = [f for f in os.listdir(src_dir) if f in patern_ids[cl_name]]
        for file_name in to_move:
            shutil.move(os.path.join(src_dir, file_name), dst_dir)
        

In [None]:
# Crop the test images into an ImageFolder layout with a single 'unknown' class.
inpath = '/kaggle/working/plates/test'
outpath = '/kaggle/working/test/unknown'

# Start from a clean output dir. Remove the same absolute path that is created below
# (a relative `rm -r test` only matched when the CWD was /kaggle/working).
shutil.rmtree('/kaggle/working/test/', ignore_errors=True)
os.makedirs(outpath)

erorrs = []
crop_and_save(inpath, outpath, erorrs)

len(erorrs)

In [None]:
# Transforms.
# `train_transforms` is applied by PlatesDataset on top of `base_transforms`, i.e. to
# tensors that are ALREADY normalised. The geometric ops run first (so RandomRotation's
# zero fill is still the mean colour). The tensor is then mapped back to [0, 1] so that
# ColorJitter (which clamps to [0, 1]) sees real intensities. It is normalised exactly
# once at the end, matching the test-time `base_transforms`.

shuffle_RGB = transforms.Lambda(lambda X: X[torch.randperm(3)])

# Inverse of Normalize(mean, std): (x - (-mean/std)) / (1/std) == x * std + mean
denormalize = transforms.Normalize(mean=-mean / std, std=1 / std)

train_transforms = transforms.Compose([
    transforms.RandomRotation(360),
    transforms.Resize(int(input_size * 1.3)),
    transforms.CenterCrop(input_size),
    transforms.RandomHorizontalFlip(),
    denormalize,
    transforms.ColorJitter(0.05, 0.05, 0.1, 0.1),
    transforms.RandomApply([shuffle_RGB], p=0.3),
    transforms.Normalize(std=std, mean=mean),
])


base_transforms = transforms.Compose([
    transforms.Resize(int(input_size * 1.5)),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(std=std, mean=mean),
])




In [None]:
# Test data: ImageFolder over /kaggle/working/test (one 'unknown' class).
test_dir = 'test'
unknown_dir = os.path.join(test_dir, 'unknown')
if not os.path.exists(unknown_dir):
    # Only when the crop cell did not create test/unknown: use the raw, uncropped images.
    shutil.copytree(os.path.join(data_root, 'test'), unknown_dir)

test_dataset = torchvision.datasets.ImageFolder('/kaggle/working/test', base_transforms)
paths = [img_path for img_path, _ in test_dataset.imgs]

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Training data: cropped train images, optionally expanded with mix-up blends.
base_dataset = torchvision.datasets.ImageFolder('plates/train', transform=base_transforms)
train_dataset = PlatesDataset(base_dataset, train_transforms, mix_up=mix_up)

# Episodic sampler for few-shot training, plus a plain shuffled loader.
sampler = TaskSampler(train_dataset, n_way=num_classes, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TASKS)
clf_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
fsl_dataloader = DataLoader(
    train_dataset,
    batch_sampler=sampler,
    collate_fn=sampler.episodic_collate_fn,
    num_workers=2,
)


In [None]:
# Preview two augmented training batches as image grids.
batch_iter = iter(clf_dataloader)
for batch_no in range(2):
    images, labels = next(batch_iter)
    grid = torchvision.utils.make_grid(images, scale_each=False, normalize=True)
    fig, ax = plt.subplots(figsize=(20, 12))
    ax.imshow(grid.permute(1, 2, 0).cpu().numpy())

In [None]:
import copy
import collections


def fit_epoch(model, train_loader, optimizer):
    """Train `model` for one pass over `train_loader`.

    Batches are ``(X, y)``. ``X`` is either a tensor or a tuple/list of tensors (the episodic
    collate yields ``(support, support_labels, query)``); it is moved to ``DEVICE``
    element-wise.

    Returns:
        ``(mean loss, mean per-batch micro-F1)`` over the epoch.
    """
    model.train()
    running_loss = 0.0
    running_f1 = 0.0

    for X, y in train_loader:
        # `collections.Sequence` no longer exists in Python 3.10+; tuple/list covers
        # both the default and the episodic collate output.
        if isinstance(X, (tuple, list)):
            X = tuple(x.to(DEVICE) for x in X)
        else:
            X = X.to(DEVICE)
        y = y.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(X)
        loss = model.loss_function(outputs, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_f1 += f1_score(y.cpu(), outputs.argmax(dim=1).cpu(), average='micro')

    n_batches = len(train_loader)
    return running_loss / n_batches, running_f1 / n_batches
  
def eval_epoch(model, val_loader):
    """Evaluate `model` on `val_loader` without gradient tracking.

    Batches are handled as in training: ``X`` may be a tensor or a tuple/list of tensors.

    Returns:
        ``(mean loss, mean per-batch micro-F1)`` over the loader.
    """
    model.eval()
    running_loss = 0.0
    running_f1 = 0.0
    with torch.no_grad():
        for X, y in val_loader:
            # `collections.Sequence` no longer exists in Python 3.10+.
            if isinstance(X, (tuple, list)):
                X = tuple(x.to(DEVICE) for x in X)
            else:
                X = X.to(DEVICE)
            y = y.to(DEVICE)

            outputs = model(X)
            loss = model.loss_function(outputs, y)

            running_loss += loss.item()
            running_f1 += f1_score(y.cpu(), outputs.argmax(dim=1).cpu(), average='micro')

    n_batches = len(val_loader)
    return running_loss / n_batches, running_f1 / n_batches
  
def train(train_loader, val_loader, model, epochs, l2_reg=0):
    """Train `model` with Adam + ReduceLROnPlateau, tracking the best model and snapshots.

    Args:
        train_loader, val_loader: loaders passed to ``fit_epoch`` / ``eval_epoch``.
        model: model exposing ``loss_function``; trained in place.
        epochs: number of epochs.
        l2_reg: Adam weight decay.

    Returns:
        history: ``[epoch, train_f1, val_f1]`` per epoch.
        best_model: deep copy of the model at the lowest validation loss (None if epochs == 0).
        snapshots: independent deep copies taken every ``epochs // 5`` epochs (every
            ``epochs`` epochs when that is 0), starting at epoch 0.
    """
    history = []
    snapshots = []
    snapshot_every = epochs // 5 if (epochs // 5) != 0 else epochs
    best_model = None
    best_loss = float('inf')
    log_template = "\nEpoch {ep:03d} train_loss: {t_loss:0.4f} val_loss {v_loss:0.4f} f1val {v_f1:0.4f}"

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_reg)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        for epoch in range(epochs):
            train_loss, train_f1 = fit_epoch(model, train_loader, optimizer)
            val_loss, val_f1 = eval_epoch(model, val_loader)
            sched.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                best_model = copy.deepcopy(model)
                print('new best model')
            if epoch % snapshot_every == 0:
                # Deep copy: appending `model` itself would store N references to the
                # same, final weights.
                snapshots.append(copy.deepcopy(model))

            history.append([epoch, train_f1, val_f1])
            pbar_outer.update(1)
            tqdm.write(log_template.format(ep=epoch+1, t_loss=train_loss, v_loss=val_loss, v_f1=val_f1))
    return history, best_model, snapshots

In [None]:
# Train a prototypical network whose backbone head is replaced by dropout.
# NOTE(review): the same episodic loader is used for both training and validation.
proto_head = nn.Dropout(0.2)
model_0 = Proto_Net(freeze_bn=True, freeze_bn_affine=True, dist='l2', fc_backbone=proto_head).to(DEVICE)
h, b_model_0, proto_models = train(fsl_dataloader, fsl_dataloader, model_0, num_epoch, l2_reg=l2_reg)

In [None]:
# Rank support images by distance to each test image, averaging the best and final models.
for net in (model_0, b_model_0):
    net.eval()

ds = base_dataset
support = DataLoader(ds, batch_size=16, shuffle=False)
sims, support_labels = calc_sims(
    test_dataloader, support,
    models=[b_model_0.backbone, model_0.backbone],
    dist='l2',
)
# Per test image: support indices ordered nearest-first (distances averaged over models).
ind = sims.mean(0).argsort(1)
print(base_dataset.classes)
show_nearest(test_dataset, ds, ind, knn=5, start_index=500, num_imgs=20)
plt.show()

In [None]:
# k-NN vote: each test image gets the most frequent label among its K_NEIGHBOURS nearest
# support images (ties resolve to the smaller label, as np.argmax picks the first max).
K_NEIGHBOURS = 7
support_targets = np.array(ds.targets)
nearest = ind.cpu().numpy()[:, :K_NEIGHBOURS]
# Fancy indexing instead of repeating the label row a hard-coded 744 times, so any
# number of test images works.
test_preds = [np.argmax(np.bincount(row)) for row in support_targets[nearest]]


In [None]:
# Map class indices to submission labels without mutating `class_names` (re-runnable).
# With 4 classes, indices 2/3 (the 'patern_*' folders, assuming ImageFolder's
# alphabetical class order) map back to 'cleaned'/'dirty'.
label_names = class_names * 2 if num_classes == 4 else class_names
test_labels = [label_names[t] for t in test_preds]

submission = pd.DataFrame({'label': test_labels, 'id': paths})
# id = file name without directory or extension, e.g. '.../unknown/0001.jpg' -> '0001'
submission['id'] = submission['id'].map(lambda p: os.path.splitext(os.path.basename(p))[0])
submission = submission.set_index('id')
submission[submission.label == 'dirty'].count()



In [None]:
# Write the submission; the 'id' index becomes the first CSV column.
submission.to_csv('submission.csv', index=True)

In [None]:
# Remove the extracted and cropped image folders from the working directory.
# shutil.rmtree with ignore_errors keeps the cell idempotent (unlike `!rm -r`).
for leftover_dir in ('test', 'plates'):
    shutil.rmtree(leftover_dir, ignore_errors=True)