In [10]:
from torchvision.datasets.utils import download_url
import os
import tarfile
import hashlib
import torch.nn as nn

# https://github.com/fastai/imagenette
dataset_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
dataset_filename = dataset_url.split('/')[-1]
dataset_foldername = dataset_filename.split('.')[0]
data_path = './data'
dataset_filepath = os.path.join(data_path,dataset_filename)
dataset_folderpath = os.path.join(data_path,dataset_foldername)

os.makedirs(data_path, exist_ok=True)

download = False
if not os.path.exists(dataset_filepath):
    download = True
else:
    md5_hash = hashlib.md5()


    file = open(dataset_filepath, "rb")

    content = file.read()

    md5_hash.update(content)


    digest = md5_hash.hexdigest()
    if digest != 'fe2fc210e6bb7c5664d602c3cd71e612':
        download = True
if download:
    download_url(dataset_url, data_path)

with tarfile.open(dataset_filepath, 'r:gz') as tar:
    tar.extractall(path=data_path)

In [11]:
class DuplicatedCompose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        img1 = img.copy()
        img2 = img.copy()
        for t in self.transforms:
            img1 = t(img1)
            img2 = t(img2)
        return img1, img2

class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]

    def __repr__(self):
        format_string = self.__class__.__name__ + '(\n\t'
        format_string += self.base_transform.__repr__().replace('\n', '\n\t')
        format_string += '\n)'
        return format_string


In [12]:
import torchvision
import torch
from torchvision.transforms import transforms

size  = 224
ks = (int(0.1 * size) // 2) * 2 + 1 # should be odd
__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                    'std': [0.229, 0.224, 0.225]}

train_transform = DuplicatedCompose([
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.1), ratio=(0.9, 1.1), interpolation=2),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])


# train_transform = TwoCropsTransform(transforms.Compose([transforms.RandomResizedCrop(scale=(0.2, 1), size=size),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
#                                       transforms.RandomGrayscale(p=0.2),
#                                       transforms.GaussianBlur(kernel_size=ks),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize(**__imagenet_stats)]))

dataset_train = torchvision.datasets.ImageFolder(os.path.join(dataset_folderpath,'train'), train_transform)
dataset_test = torchvision.datasets.ImageFolder(os.path.join(dataset_folderpath,'test'), train_transform)
#valid_ds = ImageFolder('./data/imagenette-160/val', valid_tfms)

batch_size = 128
train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        num_workers=8,
        drop_last=True,
        shuffle=True,
)


FileNotFoundError: [Errno 2] No such file or directory: './data/imagenette2/test'

In [13]:
import numpy as np

def get_numpy_samples(inputs):
        mean = torch.as_tensor(__imagenet_stats['mean'], dtype=inputs.dtype, device=inputs.device)
        std = torch.as_tensor(__imagenet_stats['std'], dtype=inputs.dtype, device=inputs.device)
        inputs = inputs * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
        inputs = inputs.numpy()
        inputs = np.transpose(inputs, (0,2,3,1))
        return inputs


In [14]:
# import matplotlib.pyplot as plt
     
# fig, axes = plt.subplots(nrows=batch_size, ncols=2, figsize=(10,100))
# for (input1, input2), _ in train_dataloader:
#     np_inputs1, np_inputs2 = get_numpy_samples(input1), get_numpy_samples(input2)
#     for row in range(batch_size):
#         axes[row, 0].axis("off")
#         axes[row, 0].imshow(np_inputs1[row])
#         axes[row, 1].axis("off")
#         axes[row, 1].imshow(np_inputs2[row])
#     break
# plt.show()


In [15]:
from torch.nn import functional as F
def contrastive_loss(z1, z2, tau=0.2):
    N = z1.shape[0]
    logits = torch.mm(z1, z2.t())  # [N, N] pairs
    labels = torch.arange(N).cuda()  # positives are in diagonal
    loss = F.cross_entropy(logits / tau, labels)
    return 2 * tau * loss


In [16]:
import torchvision.models as models
network_q = models.resnet50()
network_k = models.resnet50()

In [17]:
K = 16384
dim = 1000

class KeysQueue():
    def __init__(self):
        self.data = torch.randn(K, dim).cuda()
    
    def enqueue(self, k):
        return torch.cat([self.data, k], dim=0)

    def dequeue(self):
        if len(self.data) > K:
            return self.data[-K:]
        else:
            return self.data
    def clone(self):
        return self.data.clone()

In [18]:
N = batch_size
C = 1000

def train(net_q, net_k, train_dataloader, my_queue):
    # add args optimizer, epoch, temp=0.07
    m = 0.9
    temp = 0.9
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net_q.parameters(), lr=1e-6)
    net_q.train()
    net_k.eval() #updates using custom function by hand
    for ((inputq, inputk), _) in train_dataloader:
        
        optimizer.zero_grad()
        x_q = net_q(inputq).cuda()
        x_k = net_k(inputk).cuda()
        x_k = x_k.detach()
        
        l_pos = torch.einsum('nc,nc->n', [x_q, x_k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [x_q, my_queue.clone().T.detach().cuda()])
        
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= temp
        
        labels = torch.zeros([logits.shape[0]]).long().cuda()
#         print(labels.shape)
#         print(logits.shape)

        loss = criterion(logits, labels)
        print(loss)
        
        loss.backward()
        optimizer.step()
        
        for p_q, p_k in zip(net_q.parameters(), net_k.parameters()):
            p_k.data.mul_(m).add_(1 - m, p_q.detach().data)
        
        my_queue.enqueue(x_k)
        my_queue.dequeue()


image_queue = KeysQueue()
train(network_q, network_k, train_dataloader, image_queue)

tensor(341.5904, device='cuda:0', grad_fn=<NllLossBackward>)


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1616554793803/work/torch/csrc/utils/python_arg_parser.cpp:1005.)
  p_k.data.mul_(m).add_(1 - m, p_q.detach().data)


tensor(33.1511, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(60.2457, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(64.1586, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(61.6973, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(68.3275, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(63.5275, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(62.7126, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(62.8217, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(62.5484, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(61.9553, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(58.0762, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(50.4404, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(44.7630, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(39.8367, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(29.9959, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(14.3732, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(7.5563, device='cuda:0', grad_fn=

KeyboardInterrupt: 

In [25]:
for param in network_q.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = 1000
network_q.fc_gal = nn.Linear(num_ftrs, 10).cuda()
network_q = network_q.cuda()


In [26]:
def train_after(net_q, train_dataloader):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net_q.parameters(), lr=1e-6)
    net_q.train()
    for ((inputq, _), labels) in train_dataloader:
        optimizer.zero_grad()
        print(inputq.shape)
        x_q = net_q(inputq.cuda())
        logits = net_q.fc_gal(x_q)

        loss = criterion(logits, labels)
        print(loss)
        
        loss.backward()
        optimizer.step()
        


train_after(network_q, train_dataloader)

torch.Size([32, 3, 224, 224])


RuntimeError: mat1 dim 1 must match mat2 dim 0