In [None]:
! pip install kornia

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
import torch
import  torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt
from torchvision.datasets import VOCDetection
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

In [None]:
from kornia.geometry.transform import get_perspective_transform, warp_perspective
from skimage import transform as trf

In [None]:
# VOC 2012 trainval / val splits converted to tensors only (no normalisation or
# augmentation); the first constructor downloads the archive if it is missing.
dataset = VOCDetection(
    root='data/',
    year='2012',
    download=True,
    transform=ToTensor(),
)
test_dataset = VOCDetection(
    root='data/',
    year='2012',
    image_set='val',
    transform=ToTensor(),
)

Using downloaded and verified file: data/VOCtrainval_11-May-2012.tar
Extracting data/VOCtrainval_11-May-2012.tar to data/


In [None]:
def get_categories(labels_dir):
    """Return the sorted list of class names found in ``labels_dir``.

    A class name is the part before the first ``_`` of every file whose name
    ends in ``_train.txt``. The result is sorted because ``os.listdir`` returns
    entries in arbitrary order, and ``encode_labels`` uses each class's position
    in this list as its label index; sorting makes that mapping reproducible.

    Args:
        labels_dir: path of the directory to scan.

    Returns:
        Sorted list of class-name strings.

    Raises:
        FileNotFoundError: if ``labels_dir`` is not an existing directory.
    """
    if not os.path.isdir(labels_dir):
        raise FileNotFoundError(f"labels directory not found: {labels_dir!r}")
    return sorted(
        file.split("_")[0]
        for file in os.listdir(labels_dir)
        if file.endswith("_train.txt")
    )

# Class names; an object's label index (see encode_labels) is its position in this list.
object_categories = get_categories(
    os.path.join('.', 'data', 'VOCdevkit', 'VOC2012', 'ImageSets', 'Main')
)

def encode_labels(target, categories=None):
    """Turn a VOCDetection annotation dict into a multi-hot label vector.

    Args:
        target: annotation dict; ``target['annotation']['object']`` may be a
            single object dict or a list of object dicts (both are accepted),
            and may be absent (no objects).
        categories: ordered list of class names. Defaults to the module-level
            ``object_categories``. A class's index in this list is its position
            in the output vector.

    Returns:
        A float64 tensor of length ``len(categories)``. Each position is 1 when
        that class has at least one non-difficult object in the image, else 0.
        Objects with no ``difficult`` entry count as not difficult.
    """
    if categories is None:
        categories = object_categories
    objects = target['annotation'].get('object', [])
    if isinstance(objects, dict):
        # A single object may be stored as a bare dict rather than a list.
        objects = [objects]
    present = [categories.index(obj['name'])
               for obj in objects
               if int(obj.get('difficult', 0)) == 0]
    k = np.zeros(len(categories))
    k[present] = 1
    return torch.from_numpy(k)

def denormalize(images, means, stds):
    """Undo per-channel normalisation: return ``images * stds + means``.

    Args:
        images: tensor whose channel axis is the third-from-last one, e.g.
            ``[N, 3, H, W]``. A single ``[3, H, W]`` image broadcasts to
            ``[1, 3, H, W]``.
        means, stds: three per-channel values each.

    The statistic tensors are created on ``images.device``, so images that live
    on the GPU can be denormalised without first moving them to the CPU.
    """
    means = torch.as_tensor(means, device=images.device).reshape(1, 3, 1, 1)
    stds = torch.as_tensor(stds, device=images.device).reshape(1, 3, 1, 1)
    return images * stds + means

def show_batch(dl):
    """Plot up to 64 images of the first batch of ``dl`` as an 8-column grid.

    The images are denormalised with the module-level ``stats`` and clamped to
    [0, 1]. Nothing is drawn when ``dl`` yields no batches.
    """
    no_batch = object()
    first = next(iter(dl), no_batch)
    if first is no_batch:
        return
    images, _ = first
    _, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(denormalize(images, *stats)[:64], nrow=8)
    ax.imshow(grid.permute(1, 2, 0).clamp(0, 1))

In [None]:
# Data transforms (normalization & data augmentation).
# Per-channel (means, stds) used by Normalize below; `denormalize` undoes it.
stats = ((0.45704722, 0.43824774, 0.4061733), (0.23908591, 0.23509644, 0.2397309))

# Training: resize to a fixed 256x256, then light colour / flip augmentation.
train_tfms = tt.Compose([
    tt.Resize((256, 256)),
    tt.RandomChoice([tt.ColorJitter(brightness=(0.80, 1.20)),
                     tt.RandomGrayscale(p=0.25)]),
    tt.RandomHorizontalFlip(p=0.25),
    tt.ToTensor(),
    tt.Normalize(*stats, inplace=True),
])

# Evaluation: the same 256x256 geometry as training, without augmentation.
# Deliberately no crop: the model is built for HEIGHT = WIDTH = 256 inputs, and
# a 200x200 centre crop would also change the field of view versus training.
valid_tfms = tt.Compose([
    tt.Resize((256, 256)),
    tt.ToTensor(),
    tt.Normalize(*stats),
])

In [None]:
# Augmented training set (VOCDetection's default image_set) and the 'val' split
# with deterministic transforms; both map annotations through encode_labels.
train_ds = VOCDetection(
    root='data/',
    year='2012',
    download=True,
    transform=train_tfms,
    target_transform=encode_labels,
)
test_dataset = VOCDetection(
    root='data/',
    year='2012',
    image_set='val',
    transform=valid_tfms,
    target_transform=encode_labels,
)

Using downloaded and verified file: data/VOCtrainval_11-May-2012.tar
Extracting data/VOCtrainval_11-May-2012.tar to data/


In [None]:
# Seed torch's global RNG so the random_split below is reproducible.
torch.manual_seed(43)
val_size, batch_size = 4500, 16

In [None]:
# Split the VOC 'val' images into a validation subset of `val_size` images and a
# test subset holding the remainder; echo both sizes.
test_size = len(test_dataset) - val_size
test_ds, val_ds = random_split(test_dataset, lengths=[test_size, val_size])
(len(test_ds), len(val_ds))

(1323, 4500)

In [None]:
# Only the training loader shuffles; the evaluation loaders keep a fixed order.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=4, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size=batch_size,
                    num_workers=4, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size=batch_size,
                     num_workers=4, pin_memory=True)

In [None]:
def to_device(data, device):
    """Recursively move a tensor, or a (nested) list/tuple of tensors, to ``device``.

    Lists and tuples come back as lists; each tensor is moved with
    ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(element, device) for element in data]

class DeviceDataLoader():
    """Iterable wrapper that hands out the batches of ``dl`` already moved to ``device``."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after ``to_device``."""
        yield from map(lambda batch: to_device(batch, self.device), self.dl)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)

In [None]:
# Use the GPU when torch can see one; the last line echoes whether it does.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device.type == 'cuda'

True

In [None]:
# Wrap the loaders with the `device` selected above (not a hard-coded CPU, which
# made the wrapper a no-op), so batches arrive on the same device as the model.
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)
test_dl = DeviceDataLoader(test_dl, device)

In [None]:
# Model input size in pixels and the target aspect ratio
# (retargeted width = int(AR * WIDTH)).
HEIGHT, WIDTH = 256, 256
AR = 0.8

In [None]:
def cumulative_norm(x):
    """Normalised cumulative sum of ``x`` along its last axis.

    Every row ``x[..., :]`` is replaced by its running sum divided by the row
    total, so the last entry of each row is 1 whenever the total is non-zero.
    ``keepdim=True`` keeps the totals broadcastable for any leading shape. There
    is no hard-coded channel count or height, so the result matches the old
    ``view(-1, 3, 256, 1)`` form for ``[N, 3, 256, W]`` inputs.
    """
    totals = torch.sum(x, -1, keepdim=True)
    return torch.cumsum(x, -1) / totals

In [None]:
def inverse_warp(input, flow):
    """Bilinearly sample ``input`` at the (row, col) positions stored in ``flow``.

    Args:
        input: tensor of shape ``[N, C, H, W]`` to sample from.
        flow: tensor of shape ``[N, C, H_out, 2 * W_out]``. The first ``W_out``
            entries of the last axis are column coordinates and the last
            ``W_out`` are row coordinates, both in pixel units of ``input``.
            Plane ``(n, c)`` of ``flow`` samples plane ``(n, c)`` of ``input``.

    Returns:
        Tensor of shape ``[N, C, H_out, W_out]`` on the same device as the inputs.

    The four neighbouring indices are clamped to the image while the bilinear
    weights use the unclamped coordinates. Positions outside the image
    therefore take the value of the nearest border pixel.

    Raises:
        ValueError: if the last axis of ``flow`` has odd length.
    """
    if flow.shape[-1] % 2:
        raise ValueError(
            f"flow's last dimension must be even, got {flow.shape[-1]}")
    half = flow.shape[-1] // 2
    v_col, v_row = torch.split(flow, [half, half], dim=-1)
    i_H, i_W = input.shape[-2], input.shape[-1]

    # Integer neighbours of every sampling position.
    v_r0 = torch.floor(v_row)
    v_r1 = v_r0 + 1
    v_c0 = torch.floor(v_col)
    v_c1 = v_c0 + 1

    # Clamp to valid pixel indices; gather needs int64 indices.
    i_r0 = torch.clamp(v_r0, 0, i_H - 1).long()
    i_r1 = torch.clamp(v_r1, 0, i_H - 1).long()
    i_c0 = torch.clamp(v_c0, 0, i_W - 1).long()
    i_c1 = torch.clamp(v_c1, 0, i_W - 1).long()

    flat_input = input.flatten(-2)  # [N, C, H * W]

    def _take(rows, cols):
        # input[n, c, rows, cols] for every (n, c) plane via a flat gather.
        index = (rows * i_W + cols).flatten(-2)
        return torch.gather(flat_input, -1, index).view(rows.shape)

    f00 = _take(i_r0, i_c0)
    f01 = _take(i_r0, i_c1)
    f10 = _take(i_r1, i_c0)
    f11 = _take(i_r1, i_c1)

    # Bilinear weights.
    w00 = (v_r1 - v_row) * (v_c1 - v_col)
    w01 = (v_r1 - v_row) * (v_col - v_c0)
    w10 = (v_row - v_r0) * (v_c1 - v_col)
    w11 = (v_row - v_r0) * (v_col - v_c0)

    return w00 * f00 + w01 * f01 + w10 * f10 + w11 * f11

In [None]:
class MySkimageTransform(object):
    """Apply a fixed shear + horizontal translation (scikit-image) to every image of a batch.

    Called on a tensor batch. Returns a new NumPy array with the same shape and
    dtype as the input, holding the warped images. The input is left untouched.
    """

    def __init__(self):
        """No configuration: the warp is built from the module-level AR and WIDTH."""

    def __call__(self, pic):
        """Warp each image of ``pic`` and return the warped batch.

        NOTE(review): assumes ``pic`` is ``[N, C, H, W]`` (channels first) —
        confirm against callers.
        """
        batch = pic.cpu().detach().numpy()
        afine_tf = trf.AffineTransform(shear=0.4,
                                       translation=[(1-AR) * WIDTH, 0])
        # .numpy() can share memory with a CPU tensor, so write into a new array.
        warped = np.empty_like(batch)
        for idx, img in enumerate(batch):
            # skimage treats the first two axes as (row, col): move channels last
            # for the warp, then back to channels-first for the output.
            hwc = np.moveaxis(img, 0, -1)
            warped[idx] = np.moveaxis(trf.warp(hwc, inverse_map=afine_tf), -1, 0)
        return warped

In [None]:
# Release cached, currently unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
class RetargetModel(torch.nn.Module):
    """Retargeting network built around a pretrained VGG16.

    ``forward`` takes a ``[N, 3, HEIGHT, WIDTH]`` batch and produces a warped
    ``[N, 3, HEIGHT, int(AR * WIDTH)]`` image. It then resizes that image back to
    ``HEIGHT x WIDTH`` and classifies it with the same VGG16, whose last classifier
    layer is replaced by a 20-way output.
    """

    def __init__(self):
        super(RetargetModel, self).__init__()
        self.vgg16_model = torchvision.models.vgg16(pretrained=True)
        self.vgg16_model.classifier[6] = nn.Linear(in_features=4096,
                                                   out_features=20, bias=True)

        # Decoder: VGG16 feature map -> 3 x HEIGHT x WIDTH map.
        # * Every Upsample uses `size=` so each stage lands on a fixed resolution.
        #   A `scale_factor` of 64 / 128 would instead enlarge the map 64x / 128x
        #   per side and exhaust memory.
        # * padding=1 makes each 3x3 transposed conv size-preserving. The output is
        #   therefore exactly HEIGHT x WIDTH and lines up with `resize` / `conv1d`
        #   in forward. Unpadded, each conv adds 2 pixels.
        # The layer order (and so the state_dict keys) is unchanged.
        self.decoder = nn.Sequential(
            nn.Upsample(size=(16, 16)),
            nn.ConvTranspose2d(512, 512, kernel_size=(3, 3), padding=1), nn.ELU(),
            nn.ConvTranspose2d(512, 512, kernel_size=(3, 3), padding=1), nn.ELU(),

            nn.Upsample(size=(32, 32)),
            nn.ConvTranspose2d(512, 512, kernel_size=(3, 3), padding=1), nn.ELU(),
            nn.ConvTranspose2d(512, 512, kernel_size=(3, 3), padding=1), nn.ELU(),

            nn.Upsample(size=(64, 64)),
            nn.ConvTranspose2d(512, 256, kernel_size=(3, 3), padding=1), nn.ELU(),
            nn.ConvTranspose2d(256, 256, kernel_size=(3, 3), padding=1), nn.ELU(),

            nn.Upsample(size=(128, 128)),
            nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1), nn.ELU(),

            nn.Upsample(size=(HEIGHT, WIDTH)),
            nn.ConvTranspose2d(128, 64, kernel_size=(3, 3), padding=1), nn.ELU(),
            nn.ConvTranspose2d(64, 3, kernel_size=(3, 3), padding=1), nn.ELU(),
        )

        # Retargeted size, and the size VGG16 is fed when classifying the result.
        self.resize = tt.Resize((HEIGHT, int(AR*WIDTH)))
        self.revert_resize = tt.Resize((HEIGHT, WIDTH))

        # "Duplicate" branch: a 1 x WIDTH kernel collapses the width axis of the
        # HEIGHT x WIDTH decoder output to one value per row and channel.
        self.conv1d = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=(1, WIDTH),
                      padding='valid'),
            nn.ReLU()
        )
        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 16, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix. 24 * 16 * 7 * 7 == 32 * 28 * 21,
        # the flattened localization output for a 3 x 256 x 204 map
        # (HEIGHT = 256, int(AR * WIDTH) = 204).
        self.fc_loc = nn.Sequential(
            nn.Linear(24 * 16 * 7 * 7, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )
        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0],
                                                    dtype=torch.float))

    def forward(self, input_image):
        """Retarget ``input_image`` and classify the result.

        Args:
            input_image: batch of shape ``[N, 3, HEIGHT, WIDTH]``.

        Returns:
            ``(resized_image, output_label, a1, a2)``:
            - ``resized_image``: the warped ``[N, 3, HEIGHT, int(AR * WIDTH)]`` image.
            - ``output_label``: the ``[N, 20]`` VGG16 logits for that image.
            - ``a1``, ``a2``: the outputs of ``vgg16_model.features[0]`` applied to
              the input and to the retargeted-then-restored image, for the
              structure loss.
        """
        encoded_image = self.vgg16_model.features(input_image)
        output_image = self.decoder(encoded_image)            # [N, 3, H, W]
        resized_image = self.resize(output_image)             # [N, 3, H, AR*W]

        # DUPLICATE LAYER: one value per row, repeated across the target width.
        duplicate_image = self.conv1d(output_image)           # [N, 3, H, 1]
        duplicate_image = duplicate_image.repeat(1, 1, 1, int(AR*WIDTH))
        output_map = torch.add(resized_image, duplicate_image)  # [N, 3, H, AR*W]

        # NORM LAYER: cumulative map scaled by the number of removed columns.
        output_map = abs(WIDTH - int(AR*WIDTH)) * cumulative_norm(output_map)

        # WARP LAYER: spatial transformer predicting an affine warp.
        xs = self.localization(output_map)
        xs = xs.view(-1, 24 * 16 * 7 * 7)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        grid = F.affine_grid(theta, output_map.size(), align_corners=True)
        # Use the same align_corners convention the grid was built with
        # (grid_sample would otherwise default to False).
        resized_image = F.grid_sample(resized_image, grid, align_corners=True)

        reverted_sized_image = self.revert_resize(resized_image)
        output_label = self.vgg16_model(reverted_sized_image)

        # Structure-loss features: the first VGG16 layer applied to the input and
        # to the restored retargeted image, so both tensors are directly comparable.
        first_layer = self.vgg16_model.features[0]
        a1 = first_layer(input_image)
        a2 = first_layer(reverted_sized_image)
        return resized_image, output_label, a1, a2

In [None]:
# Smoke test: one forward pass on a blank HEIGHT x WIDTH image to check shapes.
model = RetargetModel().to(device)
with torch.no_grad():  # a shape check needs no autograd graph
    out = model(torch.zeros((1, 3, HEIGHT, WIDTH), device=device))
print(out[2].size(), out[3].size())
plt.imshow(out[0][0, 0].cpu().numpy())

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Adam over every parameter of the model, including the VGG16 sub-network.
learning_rate = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Content loss: `encode_labels` produces multi-hot vectors (several classes can be
# present in one image). That is a multi-label problem, so use per-class binary
# cross-entropy on the logits rather than softmax cross-entropy. Targets are cast
# to the logits' dtype because `encode_labels` returns float64 tensors.
def ContentLoss(logits, targets):
    """Mean binary cross-entropy between ``logits`` and multi-hot ``targets``."""
    return F.binary_cross_entropy_with_logits(logits, targets.to(logits.dtype))


# Structure loss: torch.nn has no `CosineLoss`. Compare the two feature maps with
# 1 - mean per-sample cosine similarity (0 when they point the same way).
def StructureLoss(a, b):
    """Return ``1 - mean_n cos(a[n], b[n])`` over per-sample flattened features."""
    return 1 - F.cosine_similarity(a.flatten(1), b.flatten(1), dim=1).mean()

In [None]:
from tqdm.auto import tqdm

NUM_EPOCHS = 500
LOG_EVERY = 100          # batches between loss printouts / preview images
STRUCTURE_WEIGHT = 0.5   # loss = (1 - w) * content + w * structure

model.train()
for epoch in tqdm(range(NUM_EPOCHS)):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dl):
        # .to() is a no-op if the loader already delivered batches on `device`.
        # Labels are cast to float32 to match the model's logits.
        inputs = inputs.to(device)
        labels = labels.to(device, dtype=torch.float32)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        images, predictions, a1, a2 = model(inputs)
        content_loss = ContentLoss(predictions, labels)
        structure_loss = StructureLoss(a1, a2)
        # A plain float weight: a fresh requires_grad tensor here was never
        # optimised and only added a new graph leaf every step.
        loss = (1 - STRUCTURE_WEIGHT) * content_loss + STRUCTURE_WEIGHT * structure_loss
        loss.backward()
        optimizer.step()

        # Log the same weighted loss that is being optimised.
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

            # Preview the first input image and its retargeted output.
            for preview in (inputs[0], images[0]):
                plt.imshow(denormalize(preview.detach().cpu(), *stats)
                           .permute(0, 2, 3, 1).clamp(0, 1).numpy()[0])
                plt.show()

    # To checkpoint, save once per epoch (not inside the batch loop), e.g.:
    # if epoch % 50 == 0:
    #     torch.save({'epoch': epoch,
    #                 'model_state_dict': model.state_dict(),
    #                 'optimizer_state_dict': optimizer.state_dict(),
    #                 'loss': loss.item()}, "model.pt")
print('Finished Training')

In [None]:
# Loss tensors left over from the final training step.
(content_loss, structure_loss)

In [None]:
# Shape after denormalising one output image (broadcasting adds a leading batch dim).
denormalize(images[0].cpu(), *stats).shape

In [None]:
# Same image in channels-last layout, as handed to plt.imshow.
denormalize(images[0].cpu(), *stats).permute(0, 2, 3, 1).shape

In [None]:
# Display the first retargeted image of the final batch.
plt.imshow(
    denormalize(images[0].cpu(), *stats)
    .permute(0, 2, 3, 1)
    .clamp(0, 1)
    .detach()
    .numpy()[0]
)
plt.show()

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)