In [1]:
from torchvision.datasets.utils import download_url
import os
import tarfile
import hashlib
import torch.nn as nn
import copy
import math


# https://github.com/fastai/imagenette
dataset_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
dataset_filename = dataset_url.split('/')[-1]
dataset_foldername = dataset_filename.split('.')[0]
data_path = './data'
dataset_filepath = os.path.join(data_path,dataset_filename)
dataset_folderpath = os.path.join(data_path,dataset_foldername)

os.makedirs(data_path, exist_ok=True)

download = False
if not os.path.exists(dataset_filepath):
    download = True
else:
    md5_hash = hashlib.md5()


    file = open(dataset_filepath, "rb")

    content = file.read()

    md5_hash.update(content)


    digest = md5_hash.hexdigest()
    if digest != 'fe2fc210e6bb7c5664d602c3cd71e612':
        download = True
if download:
    download_url(dataset_url, data_path)

with tarfile.open(dataset_filepath, 'r:gz') as tar:
    tar.extractall(path=data_path)
    
with open("tmp.txt",'w') as tmp:
    tmp.write("hello")

In [2]:
class DuplicatedCompose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        img1 = img.copy()
        img2 = img.copy()
        for t in self.transforms:
            img1 = t(img1)
            img2 = t(img2)
        return img1, img2

class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]

    def __repr__(self):
        format_string = self.__class__.__name__ + '(\n\t'
        format_string += self.base_transform.__repr__().replace('\n', '\n\t')
        format_string += '\n)'
        return format_string


In [3]:
import torchvision
import torch
from torchvision.transforms import transforms

size  = 224
ks = (int(0.1 * size) // 2) * 2 + 1 # should be odd
__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                    'std': [0.229, 0.224, 0.225]}

train_transform = DuplicatedCompose([
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.1), ratio=(0.9, 1.1), interpolation=2),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])


# train_transform = TwoCropsTransform(transforms.Compose([transforms.RandomResizedCrop(scale=(0.2, 1), size=size),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
#                                       transforms.RandomGrayscale(p=0.2),
#                                       transforms.GaussianBlur(kernel_size=ks),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize(**__imagenet_stats)]))

dataset_train = torchvision.datasets.ImageFolder(os.path.join(dataset_folderpath,'train'), train_transform)
dataset_test = torchvision.datasets.ImageFolder(os.path.join(dataset_folderpath,'val'), train_transform)
#valid_ds = ImageFolder('./data/imagenette-160/val', valid_tfms)

batch_size = 128
train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        num_workers=8,
        drop_last=True,
        shuffle=True,
)




In [4]:
import numpy as np

def get_numpy_samples(inputs):
        mean = torch.as_tensor(__imagenet_stats['mean'], dtype=inputs.dtype, device=inputs.device)
        std = torch.as_tensor(__imagenet_stats['std'], dtype=inputs.dtype, device=inputs.device)
        inputs = inputs * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
        inputs = inputs.numpy()
        inputs = np.transpose(inputs, (0,2,3,1))
        return inputs


In [5]:
# import matplotlib.pyplot as plt
     
# fig, axes = plt.subplots(nrows=batch_size, ncols=2, figsize=(10,100))
# for (input1, input2), _ in train_dataloader:
#     np_inputs1, np_inputs2 = get_numpy_samples(input1), get_numpy_samples(input2)
#     for row in range(batch_size):
#         axes[row, 0].axis("off")
#         axes[row, 0].imshow(np_inputs1[row])
#         axes[row, 1].axis("off")
#         axes[row, 1].imshow(np_inputs2[row])
#     break
# plt.show()


In [6]:
from torch.nn import functional as F
def contrastive_loss(z1, z2, tau=0.2):
    N = z1.shape[0]
    logits = torch.mm(z1, z2.t())  # [N, N] pairs
    labels = torch.arange(N).cuda()  # positives are in diagonal
    loss = F.cross_entropy(logits / tau, labels)
    return 2 * tau * loss


In [31]:
import torchvision.models as models
network_q = models.resnet50()
network_k = models.resnet50()

In [32]:
K = 800
dim = 1000

class KeysQueue():
    def __init__(self):
        self.data = torch.randn(K, dim).cuda()
    
    def enqueue(self, k):
        return torch.cat([self.data, k], dim=0)

    def dequeue(self):
        if len(self.data) > K:
            return self.data[-K:]
        else:
            return self.data
    def clone(self):
        return self.data.clone()

In [33]:
N = batch_size
C = 1000

def train(net_q, net_k, train_dataloader, my_queue):
    # add args optimizer, epoch, temp=0.07
    i = 0
    total_loss = 0
    avg_loss = math.inf
    m = 0.999
    temp = 0.07
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net_q.parameters(), lr=3e-4)
    net_q.train()
    net_k.eval() #updates using custom function by hand
    for ((inputq, inputk), _) in train_dataloader:
        i+=1
#         inputq = inputq.cuda()
#         inputk = inputk.cuda()
        
        optimizer.zero_grad()
        
        x_q = net_q(inputq).cuda()
        x_q = nn.functional.normalize(x_q, dim=1)
        x_k = net_k(inputk).cuda()
        x_k = nn.functional.normalize(x_k, dim=1)
        x_k = x_k.detach()
        
        l_pos = torch.einsum('nc,nc->n', [x_q, x_k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [x_q, my_queue.clone().T.detach().cuda()])
        
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= temp
        
        labels = torch.zeros([logits.shape[0]]).long().cuda()
#         print(labels.shape)
#         print(logits.shape)

        loss = criterion(logits, labels)
        print(loss)
        
        loss.backward()
        optimizer.step()
        
        for p_q, p_k in zip(net_q.parameters(), net_k.parameters()):
            p_k.data.copy_(m*p_k.data + p_q.data*(1-m))
        
        my_queue.enqueue(x_k)
        my_queue.dequeue()
        
        total_loss += loss.item()
        if avg_loss > total_loss / i:
            avg_loss = total_loss / i
            torch.save(net_q.state_dict(),"./Untitled Folder/model_q_epoch_"+str(i)+".pt")
            torch.save(net_k.state_dict(),"./Untitled Folder/model_k_epoch_"+str(i)+".pt")



image_queue = KeysQueue()
train(network_q, network_k, train_dataloader, image_queue)

tensor(40.3756, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(40.4764, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(36.9782, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(33.2539, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(31.3089, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(29.1649, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(27.1785, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(25.2685, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(24.0952, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(22.8473, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(21.8308, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(20.9320, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(20.1435, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(19.4464, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(18.7094, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(18.0100, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(17.3301, device='cuda:0', grad_fn

KeyboardInterrupt: 

In [34]:
pretext_network = copy.deepcopy(network_q)

In [35]:
for param in pretext_network.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = 1000
pretext_network.fc_gal = nn.Linear(num_ftrs, 10).cuda()
pretext_network = pretext_network.cuda()


In [39]:
def train_after(net_q, train_dataloader):
    
    criterion = nn.CrossEntropyLoss()
    parameters = list(filter(lambda p: p.requires_grad, net_q.parameters()))
    assert len(parameters) == 2  # fc.weight, fc.bias
    print(parameters)
    optimizer = torch.optim.Adam(parameters, lr=0.0003)
    #optimizer = torch.optim.SGD(net_q.parameters(), lr=1e-6)
    net_q.train()
    
    for epoch in range(10):
        acc = 0
        total_loss = 0
        for ((inputq, _), labels) in train_dataloader:
            labels = labels.cuda()
            optimizer.zero_grad()
            x_q = net_q(inputq.cuda())
            logits = net_q.fc_gal(x_q)

            loss = criterion(logits, labels)
            pred = torch.argmax(logits,dim=-1)

            acc += (labels == pred).sum() / (batch_size * len(train_dataloader))
            total_loss+=loss

            loss.backward()
            optimizer.step()

#         acc = acc / len(train_dataloader)
        total_loss = total_loss /  len(train_dataloader)
        print("acc: "+str(acc))
        print("loss: "+str(total_loss))



train_after(pretext_network, train_dataloader)

[Parameter containing:
tensor([[ 0.0241, -0.0078, -0.0203,  ..., -0.0073,  0.0152,  0.0056],
        [-0.0136, -0.0071, -0.0314,  ...,  0.0148,  0.0081, -0.0238],
        [ 0.0262,  0.0196, -0.0162,  ...,  0.0108, -0.0262, -0.0032],
        ...,
        [ 0.0385,  0.0258, -0.0101,  ..., -0.0251,  0.0211,  0.0172],
        [ 0.0287, -0.0068, -0.0168,  ..., -0.0284,  0.0338, -0.0061],
        [-0.0124, -0.0162, -0.0245,  ..., -0.0490,  0.0121, -0.0016]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([ 0.0279,  0.0086,  0.0355, -0.0089,  0.0179, -0.0104,  0.0194,  0.0150,
         0.0053, -0.0115], device='cuda:0', requires_grad=True)]
acc: tensor(0.2102, device='cuda:0')
loss: tensor(2.2088, device='cuda:0', grad_fn=<DivBackward0>)
acc: tensor(0.2156, device='cuda:0')
loss: tensor(2.1780, device='cuda:0', grad_fn=<DivBackward0>)
acc: tensor(0.2148, device='cuda:0')
loss: tensor(2.1724, device='cuda:0', grad_fn=<DivBackward0>)
acc: tensor(0.2176, device='cuda:0