In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [20]:
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import random

In [21]:
'''
Implementation of Barlow Twins (https://arxiv.org/abs/2103.03230), adapted for ease of use for experiments from
https://github.com/facebookresearch/barlowtwins, with some modifications using code from 
https://github.com/lucidrains/byol-pytorch
'''

def flatten(t):
    return t.reshape(t.shape[0], -1)

class NetWrapper(nn.Module):

    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def get_representation(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x):
        representation = self.get_representation(x)

        return representation



def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [22]:
class BarlowTwins(nn.Module):
    '''
    Adapted from https://github.com/facebookresearch/barlowtwins for arbitrary backbones, and arbitrary choice of which
    latent representation to use. Designed for models which can fit on a single GPU (though training can be parallelized
    across multiple as with any other model). Support for larger models can be done easily for individual use cases by
    by following PyTorch's model parallelism best practices.
    '''
    def __init__(self, backbone, latent_id, projection_sizes, lambd, scale_factor=1):
        '''
        
        
        :param backbone: Model backbone
        :param latent_id: name (or index) of the layer to be fed to the projection MLP
        :param projection_sizes: size of the hidden layers in the projection MLP
        :param lambd: tradeoff function
        :param scale_factor: Factor to scale loss by, default is 1
        '''
        super().__init__()
        self.backbone = backbone
        self.backbone = NetWrapper(self.backbone, latent_id)
        self.lambd = lambd
        self.scale_factor = scale_factor
        # projector
        sizes = projection_sizes
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        z1 = self.backbone(y1)
        z2 = self.backbone(y2)
        z1 = self.projector(z1)
        z2 = self.projector(z2)

        # empirical cross-correlation matrix
        c = torch.mm(self.bn(z1).T, self.bn(z2))
        c.div_(z1.shape[0])


        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = self.scale_factor*(on_diag + self.lambd * off_diag)
        return loss

In [23]:
'''
#####
Adapted from https://github.com/facebookresearch/barlowtwins
#####
'''
class GaussianBlur(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            sigma = random.random() * 1.9 + 0.1
            return img.filter(ImageFilter.GaussianBlur(sigma))
        else:
            return img


class Solarization(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class Transform:
    def __init__(self, transform=None, transform_prime=None):
        '''

        :param transform: Transforms to be applied to first input
        :param transform_prime: transforms to be applied to second
        '''
        if transform == None:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform
        if transform_prime == None:

            self.transform_prime = transforms.Compose([
                transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=0.1),
                Solarization(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform_prime = transform_prime

    def __call__(self, x):
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2

In [39]:
from torch.utils.data import random_split, DataLoader

dataset = dsets.CIFAR10(root="./data", train=True, transform=Transform(), download=True)
total_size = len(dataset)
train_size = int(0.6 * total_size)
fine_tune_size = int(0.2 * total_size)
test_size = total_size - train_size - fine_tune_size

train_dataset, fine_tune_dataset, test_dataset = random_split(dataset, [train_size, fine_tune_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

print("Unlabeled size:", len(train_dataset))
print("Fine-tune size:", len(fine_tune_dataset))
print("Test size:", len(test_dataset))

Files already downloaded and verified
Unlabeled size: 30000
Fine-tune size: 10000
Test size: 10000


In [None]:
model = torchvision.models.resnet18(zero_init_residual=True)
    

learner = BarlowTwins(model, 'avgpool', [512,1024, 1024, 1024],3.9e-3, 1)

optimizer = torch.optim.Adam(learner.parameters(), lr=0.001)

#Single training epoch
cnt=0
for batch_idx, ((x1,x2), _) in enumerate(train_dataloader):
    loss = learner(x1, x2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Batch {batch_idx} Done')

Batch 0 Done
Batch 1 Done
Batch 2 Done
Batch 3 Done
Batch 4 Done
Batch 5 Done
Batch 6 Done
Batch 7 Done
Batch 8 Done
Batch 9 Done
Batch 10 Done
Batch 11 Done
Batch 12 Done
Batch 13 Done
Batch 14 Done
Batch 15 Done
Batch 16 Done
Batch 17 Done
Batch 18 Done
Batch 19 Done
Batch 20 Done
Batch 21 Done
Batch 22 Done
Batch 23 Done
Batch 24 Done
Batch 25 Done
Batch 26 Done
Batch 27 Done
Batch 28 Done
Batch 29 Done
Batch 30 Done
Batch 31 Done
Batch 32 Done
Batch 33 Done
Batch 34 Done
Batch 35 Done
Batch 36 Done
Batch 37 Done
Batch 38 Done
Batch 39 Done
Batch 40 Done
Batch 41 Done
Batch 42 Done
Batch 43 Done
Batch 44 Done
Batch 45 Done
Batch 46 Done
Batch 47 Done
Batch 48 Done
Batch 49 Done
Batch 50 Done
Batch 51 Done
Batch 52 Done
Batch 53 Done
Batch 54 Done
Batch 55 Done
Batch 56 Done
Batch 57 Done
Batch 58 Done
Batch 59 Done
Batch 60 Done
Batch 61 Done
Batch 62 Done
Batch 63 Done
Batch 64 Done
Batch 65 Done
Batch 66 Done
Batch 67 Done
Batch 68 Done
Batch 69 Done
Batch 70 Done
Batch 71 Done
Ba

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    running_loss = 0.
    #total_sample = 0
    for batch_i, data in enumerate(dataloader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
#         inputs, labels = data[0], data[1]
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # Compute prediction and loss
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        num_batch_print = 200
        if batch_i % num_batch_print == num_batch_print - 1:
            current = (batch_i + 1) * len(inputs)
            print(f'batch: [{batch_i + 1:5d}], loss: {running_loss / num_batch_print:.3f}')
            print(f'temporary loss: {loss.item():>7f} | [{current:>5d}/{size:>5d}]')
            running_loss = 0

In [None]:
def test_loop(dataloader, model, loss_fn):
    model.eval()
    #size = len(dataloader.dataset)
    num_batches = len(dataloader)
    
    test_loss, correct = 0, 0
    total = 0
    
    with torch.no_grad():
        for data in dataloader:
            # get the inputs for test dataset
            images, labels = data[0].to(device), data[1].to(device)
#             images, labels = data[0], data[1]

            
            # calculate the outputs
            outputs = model(images)
            
            # classify which class the output in
            _, predicted = torch.max(outputs.data, 1)
            #_, labels_value = torch.max(labels.data, 1)
            
            # obtain the statistics of test loss and correctness
            test_loss += loss_fn(outputs, labels).item()
            correct += (predicted == labels).sum().item() 
            #correct += (predicted == labels_value).sum().item() 
            total += labels.size(0)

    print(f"Test Error: \n Accuracy: {(100 * correct / total):>0.1f}%")
    print(f"Avg loss: {test_loss / num_batches:>8f} \n")

In [None]:
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 2, device=device)

model.trainable = True

set_trainable = False

for layer in model.layers:
    if layer.name == 'avgpool':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False

for layer in conv_base.layers:
    print(layer.name,layer.trainable)

model = model.to(device)

criterion = nn.CrossEntropyLoss()

In [None]:
# new
# criterion = nn.CrossEntropyLoss()
lr_list = [1e-2]*5
time_start = time.time()
# Pre-trained ResNet training
for t in range( len(lr_list) ):
    print(f"-------------Epoch {t+1}-------------")
    #optimizer = optim.SGD(model.parameters(), lr=lr_list[t], momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=lr_list[t])
    train_loop(train_dataloader, model, criterion, optimizer)
    test_loop(test_dataloader, model, criterion)
print("Over")
time_end = time.time()
print("Time Consumption",time_end-time_start)