In [1]:
import os
import typing as t

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sn
import torch
from sklearn.metrics import accuracy_score
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset
from torchvision.models.resnet import resnet18
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm_notebook
from tqdm.auto import tqdm

from ido_cv import draw_images
from implementation.gridmix_pytorch import GridMixAugLoss

In [16]:
class SampleDataset(Dataset):
    """Synthetic classification dataset of solid-colour RGB images.

    Each sample is an ``image_size x image_size x 3`` image filled with the
    colour assigned to its class in ``classes``; the class index is the target.
    Labels are drawn once, from the global NumPy RNG, when the dataset is built.
    """

    # Class index -> RGB fill colour.
    classes = {
        0: [255, 0, 0],
        1: [0, 255, 0],
        2: [0, 0, 255],
        3: [0, 255, 255]
    }

    def __init__(
        self,
        num_samples: int = 1000,
        image_size: int = 224,
        augs: t.Optional[t.Callable[[np.ndarray], t.Any]] = None,
    ):
        """
        Args:
            num_samples: number of samples (random labels) to generate.
            image_size: height and width of each generated image, in pixels.
            augs: transform applied to each HxWx3 uint8 image. Defaults to
                torchvision ``ToTensor()`` (float CxHxW tensor in [0, 1]).
        """
        # Sample uniformly over the actual class keys. The previous
        # randint(min, max) excluded its upper bound, so the last class
        # was never generated.
        self.class_list = np.random.choice(list(self.classes), size=num_samples).tolist()
        self.image_size = image_size
        self.augs = ToTensor() if augs is None else augs

    def __getitem__(self, idx: int) -> t.Dict[str, t.Any]:
        """Return ``{'image': augs(image), 'target': label}`` for sample ``idx``."""
        label = self.class_list[idx]
        # Build the image directly as uint8: multiplying a uint8 array by a
        # Python list upcasts to int64, and ToTensor only rescales uint8 input
        # to [0, 1], so int64 images would reach the model as raw 0..255 values.
        image = np.full(
            (self.image_size, self.image_size, 3), self.classes[label], dtype=np.uint8
        )
        return {
            'image': self.augs(image),
            'target': label
        }

    def __len__(self) -> int:
        """Number of samples (fixed at construction)."""
        return len(self.class_list)
    
class SimpleModel(nn.Module):
    """Small CNN classifier.

    Two 3x3 stride-2 convolutions (128 channels, no bias), each followed by a
    ReLU, then global average pooling and a linear head that emits
    ``out_classes`` logits.
    """

    def __init__(self, out_classes: int = 4):
        super().__init__()
        # Layers are registered in this exact order; it fixes both the
        # state_dict key order and the sequence of seeded weight inits.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, out_classes)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to (N, out_classes) logits."""
        hidden = self.relu1(self.conv1(x))
        hidden = self.relu2(self.conv2(hidden))
        pooled = torch.flatten(self.avgpool(hidden), 1)
        return self.fc(pooled)
    
def accuracy(preds: torch.Tensor, trues: torch.Tensor) -> float:
    """Batch accuracy of ``preds`` against one of two stacked label sets.

    ``trues[0]`` and ``trues[1]`` are read as label columns and ``trues[-1]``
    as the mixing coefficient (presumably the layout returned by
    ``GridMixAugLoss.get_sample`` -- TODO confirm against that class). The
    first element of ``trues[-1]`` is used batch-wide: above 0.5 selects
    ``trues[0]``, otherwise ``trues[1]``.

    Args:
        preds: model outputs, one row of class scores per sample.
        trues: stacked targets; each label set is squeezed on axis 1, so it is
            expected to be a (batch, 1) column.

    Returns:
        Fraction of samples whose argmax class equals the selected label.
    """
    mix_coef = trues[-1, :][0].detach().cpu().numpy()[0]
    selected = trues[0, :] if mix_coef > 0.5 else trues[1, :]
    label_idx = np.squeeze(selected.long().detach().cpu().numpy(), axis=1).astype(np.uint8)

    probs = torch.softmax(preds, dim=1).float()
    pred_idx = probs.detach().cpu().numpy().argmax(axis=1).astype(np.uint8)

    return float(accuracy_score(label_idx, pred_idx))

In [17]:
# Seed the RNGs so Restart & Run All reproduces the same run: NumPy drives
# SampleDataset's label draw, torch drives weight initialisation. GridMix's
# own randomness presumably uses one of these too -- TODO confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU instead of crashing on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = SampleDataset(num_samples=1000)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    num_workers=1,
    shuffle=False
)
model = SimpleModel().to(DEVICE)
optimizer = SGD(params=model.parameters(), lr=0.0001)
criterion = GridMixAugLoss(
    alpha=(0.4, 0.7),
    hole_aspect_ratio=1.,
    crop_area_ratio=(0.5, 1),
    crop_aspect_ratio=(0.5, 2),
    n_holes_x=(2, 6)
)

In [18]:
# Follow whatever device the setup cell put the model on, rather than
# hard-coding "cuda" here as well.
device = next(model.parameters()).device

model.train()
for epoch in range(5):
    # Per-epoch history; the progress bar shows the mean over batches so far.
    losses = []
    metrics = []

    # The context manager closes the progress bar at the end of each epoch,
    # even if a batch raises.
    with tqdm(total=len(train_dataloader.dataset), desc=f"Epoch {epoch}") as tq:
        for data in train_dataloader:
            # Images to float; integer labels (B,) -> (B, 1) column
            images = data['image'].float().to(device)
            labels = data['target'].unsqueeze(-1).to(device)

            # Apply GridMix
            inputs, targets = criterion.get_sample(images=images, targets=labels)

            # Predict
            preds = model(inputs)

            # Batch accuracy (see accuracy() for how mixed targets are scored)
            metrics.append(accuracy(preds, targets))

            # Loss
            loss = criterion(preds, targets)
            losses.append(loss.item())

            # Optimisation step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Progress bar
            tq.update(inputs.shape[0])
            tq.set_postfix(loss=np.mean(losses), metric=np.mean(metrics))
        
    

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))





HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))