In [1]:
from torchvision.datasets.utils import download_url
import os
import tarfile
import hashlib
import torch.nn as nn
import copy
import math
import torchvision.models as models
import timeit
import random
from PIL import ImageFilter
device = "cuda:0"


# https://github.com/fastai/imagenette
dataset_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
dataset_filename = dataset_url.split('/')[-1]
dataset_foldername = dataset_filename.split('.')[0]
data_path = './data'
dataset_filepath = os.path.join(data_path,dataset_filename)
dataset_folderpath = os.path.join(data_path,dataset_foldername)

os.makedirs(data_path, exist_ok=True)

download = False
if not os.path.exists(dataset_filepath):
    download = True
else:
    md5_hash = hashlib.md5()


    file = open(dataset_filepath, "rb")

    content = file.read()

    md5_hash.update(content)


    digest = md5_hash.hexdigest()
    if digest != 'fe2fc210e6bb7c5664d602c3cd71e612':
        download = True
if download:
    download_url(dataset_url, data_path)

with tarfile.open(dataset_filepath, 'r:gz') as tar:
    tar.extractall(path=data_path)
    
with open("tmp.txt",'w') as tmp:
    tmp.write("hello")

In [10]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/moco')
%tensorboard --logdir logs

# !pip3 install tensorboard

UsageError: Line magic function `%tensorboard` not found.


In [11]:
class DuplicatedCompose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        img1 = img.copy()
        img2 = img.copy()
        for t in self.transforms:
            img1 = t(img1)
            img2 = t(img2)
        return img1, img2

class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]

    def __repr__(self):
        format_string = self.__class__.__name__ + '(\n\t'
        format_string += self.base_transform.__repr__().replace('\n', '\n\t')
        format_string += '\n)'
        return format_string
    
    
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

In [28]:
import torchvision
import torch
from torchvision.transforms import transforms



image_size = 224
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


augmentation = [
    transforms.RandomResizedCrop(image_size, scale=(0.2, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
]

train_transform = TwoCropsTransform(transforms.Compose(augmentation))

dataset_train = torchvision.datasets.ImageFolder(os.path.join(dataset_folderpath,'train'), train_transform)
dataset_test = torchvision.datasets.ImageFolder(os.path.join(dataset_folderpath,'val'), train_transform)

batch_size = 64
train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
        shuffle=True,
)


In [29]:
class Moco2(nn.Module):
    def __init__(self, encoder):
        super(Moco2, self).__init__()
        self.moco = encoder
        self.hidden_dim = encoder.fc.weight.shape[1]
        self.linear = nn.Linear(self.hidden_dim,self.hidden_dim)
        self.activation = nn.ReLU()
        self.fc = nn.Sequential(self.linear,self.activation,encoder.fc)
        self.moco.fc = self.fc
        
    def forward(self, x):
        return self.moco(x)

In [30]:
# network_q = models.resnet50(num_classes=128)
# network_q = network_q.to(device)
# network_k = models.resnet50(num_classes=128)

network_q = Moco2(models.resnet50(num_classes=128))
network_q.load_state_dict(torch.load("./Untitled Folder/new_model_q_epoch_720.pt"))
network_q = network_q.to(device)
network_k = Moco2(models.resnet50(num_classes=128))
network_k.load_state_dict(torch.load("./Untitled Folder/new_model_k_epoch_720.pt"))


<All keys matched successfully>

In [31]:
K = 896
dim = 128

class KeysQueue():
    def __init__(self):
        self.data = torch.randn(K, dim).to(device)
        self.queue_ptr = 0
#     def enqueue(self, k):
#         self.data =  torch.cat([self.data, k], dim=0)

#     def dequeue(self):
#         if len(self.data) > K:
#             self.data = self.data[-K:]
#         else:
#             return
#     def clone(self):
#         return self.data.clone()

    def _dequeue_and_enqueue(self, keys):
#         gather keys before updating queue
#         keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = self.queue_ptr
        assert K % batch_size == 0  # for simplicity

        self.data[ptr:ptr + batch_size] = keys
        ptr = (ptr + batch_size) % K  # move pointer

        self.queue_ptr = ptr

In [32]:
# !pip3 install --upgrade wandb==0.10.8

In [33]:
import time
N = batch_size
C = 128
def train(net_q, net_k, train_dataloader, my_queue):
    # add args optimizer, epoch, temp=0.07
    avg_loss = math.inf
    m = 0.99
    temp = 0.02
    num_epochs = 2000
    net_k = net_k.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net_q.parameters(), lr=3e-4)
#     optimizer = torch.optim.SGD(net_q.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-6)
    net_q.train()    
    for epoch in range (720, num_epochs):
        total_loss = 0

        print(f"Starting epoch {epoch} !\n")
        for ((inputq, inputk), _) in train_dataloader:
            optimizer.zero_grad()


            x_q = net_q(inputq.to(device)).to(device)
            x_q = nn.functional.normalize(x_q, dim=1)
            x_k = None

            with torch.no_grad():
                x_k = net_k(inputk.to(device)).to(device)
                x_k = nn.functional.normalize(x_k, dim=1)
#                 x_k = x_k.detach()

            l_pos = torch.bmm(x_q.view(N,1,C), x_k.view(N,C,1)).squeeze(-1)
            l_neg = torch.mm(x_q.view(N,C), my_queue.data.clone().T.detach())


            logits = torch.cat([l_pos, l_neg], dim=1)
            logits /= temp

            labels = torch.zeros([logits.shape[0]]).long().to(device)
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()
            with torch.no_grad():
                for p_q, p_k in zip(net_q.parameters(), net_k.parameters()):
                    p_k.data.copy_(m*p_k.data + p_q.data*(1-m))

            my_queue._dequeue_and_enqueue(x_k)
#             my_queue.enqueue(x_k)
#             my_queue.dequeue()

            total_loss += loss.item()
        print(f"Loss : {total_loss / len(train_dataloader)}")
        writer.add_scalar('Moco loss',
                            total_loss / len(train_dataloader),
                            epoch)

        if avg_loss > total_loss / (epoch + 1):
            avg_loss = total_loss / (epoch + 1)
            torch.save(net_q.state_dict(),"./Untitled Folder/new_model_q_epoch_"+str(epoch)+".pt")
            torch.save(net_k.state_dict(),"./Untitled Folder/new_model_k_epoch_"+str(epoch)+".pt")



image_queue = KeysQueue()
for param in network_k.parameters():
    param.requires_grad = False

train(network_q, network_k, train_dataloader, image_queue)

Starting epoch 720 !

Loss : 9.828462841762166
Starting epoch 721 !

Loss : 0.7842963674441487
Starting epoch 722 !

Loss : 0.8265370349494778
Starting epoch 723 !

Loss : 0.7687063673321082
Starting epoch 724 !

Loss : 0.7704377614316487
Starting epoch 725 !

Loss : 0.783128634804771
Starting epoch 726 !

Loss : 0.7531797535159961
Starting epoch 727 !

Loss : 0.7391796706079625
Starting epoch 728 !

Loss : 0.7844102469836773
Starting epoch 729 !

Loss : 0.7617013651092036
Starting epoch 730 !

Loss : 0.7825598593066339
Starting epoch 731 !

Loss : 0.7762201004693298
Starting epoch 732 !

Loss : 0.7667141599314553
Starting epoch 733 !

Loss : 0.7802735668461339
Starting epoch 734 !

Loss : 0.7723750510588795
Starting epoch 735 !

Loss : 0.7660526692056332
Starting epoch 736 !

Loss : 0.797857942021623
Starting epoch 737 !

Loss : 0.7744176215460511
Starting epoch 738 !

Loss : 0.773586788753263
Starting epoch 739 !

Loss : 0.760027639111694
Starting epoch 740 !

Loss : 0.79648577497929

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7705439424433675
Starting epoch 762 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7888864303121761
Starting epoch 763 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7593545925860502
Starting epoch 764 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7556078496838914
Starting epoch 765 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7431464272291481
Starting epoch 766 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7589884357792991
Starting epoch 767 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7788409462996891
Starting epoch 768 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.6962699032559687
Starting epoch 769 !



Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>Traceback (most recent call last):

  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
        self._shutdown_workers()
self._shutdown_workers()  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers

  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)    
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)  File "/home/guy.shapira/miniconda3/envs/inv/lib/pyt

Loss : 0.7555094430235778
Starting epoch 770 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7327621187888035
Starting epoch 771 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7138067212234549
Starting epoch 772 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join

Traceback (most recent call last):
      File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
assert self._parent_pid == os.getpid(), 'can only join a child process'
    AssertionErrorself._shutdown_workers(): 
can only join a child process  File "/home/g

Loss : 0.7596369610757244
Starting epoch 773 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7179522676532771
Starting epoch 774 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7374315146280794
Starting epoch 775 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7583241109945336
Starting epoch 776 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7213829593593571
Starting epoch 777 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7492334159458576
Starting epoch 778 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7371237251628824
Starting epoch 779 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
assert self._parent_pid == os.getpid(), 'can only join a child process'  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__

    AssertionErrorself._shutdown_workers(): 
can only join a child process  File "/home/g

Loss : 0.7390756819929395
Starting epoch 780 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.760055526786921
Starting epoch 781 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.731624361609115
Starting epoch 782 !



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4ffdb2b560>
Traceback (most recent call last):
  File "/home/guy.shapira/miniconda3/envs/inv/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/home/

Loss : 0.7048228312106359
Starting epoch 783 !

Loss : 0.7317454993319349
Starting epoch 784 !

Loss : 0.7311820706137183
Starting epoch 785 !

Loss : 0.7141827588178673
Starting epoch 786 !

Loss : 0.7308354045258088
Starting epoch 787 !

Loss : 0.7167634325368064
Starting epoch 788 !

Loss : 0.7392583051506354
Starting epoch 789 !

Loss : 0.7134620048561875
Starting epoch 790 !

Loss : 0.7257950502593501
Starting epoch 791 !

Loss : 0.7250279421303548
Starting epoch 792 !

Loss : 0.7091347267027615
Starting epoch 793 !

Loss : 0.7664200361488628
Starting epoch 794 !

Loss : 0.7034972695266308
Starting epoch 795 !

Loss : 0.7367635201029226
Starting epoch 796 !

Loss : 0.7309630742283906
Starting epoch 797 !

Loss : 0.685928385071203
Starting epoch 798 !

Loss : 0.7141605297318933
Starting epoch 799 !

Loss : 0.7198294111255075
Starting epoch 800 !

Loss : 0.6857879672731672
Starting epoch 801 !

Loss : 0.7318890510367698
Starting epoch 802 !

Loss : 0.724220195595099
Starting epoch 8

Starting epoch 954 !

Loss : 0.5576116155807663
Starting epoch 955 !

Loss : 0.52868520139026
Starting epoch 956 !

Loss : 0.5718388618255148
Starting epoch 957 !

Loss : 0.5509556100076559
Starting epoch 958 !

Loss : 0.5526843623441904
Starting epoch 959 !

Loss : 0.5615994848158895
Starting epoch 960 !

Loss : 0.5580108042071465
Starting epoch 961 !

Loss : 0.5716348787148794
Starting epoch 962 !

Loss : 0.5457124280280807
Starting epoch 963 !

Loss : 0.5519044138744574
Starting epoch 964 !

Loss : 0.5811557957307011
Starting epoch 965 !

Loss : 0.5732317934839093
Starting epoch 966 !

Loss : 0.577310569432317
Starting epoch 967 !

Loss : 0.5637615799498396
Starting epoch 968 !

Loss : 0.5556383217070379
Starting epoch 969 !

Loss : 0.5596842054201632
Starting epoch 970 !

Loss : 0.5493642965547082
Starting epoch 971 !

Loss : 0.5636287231226357
Starting epoch 972 !

Loss : 0.5449746795454804
Starting epoch 973 !

Loss : 0.5520699318168926
Starting epoch 974 !

Loss : 0.549479419885

Loss : 0.4419246619047762
Starting epoch 1123 !

Loss : 0.42689649953323155
Starting epoch 1124 !

Loss : 0.4449350626087513
Starting epoch 1125 !

Loss : 0.4465734961689735
Starting epoch 1126 !

Loss : 0.46658164607424313
Starting epoch 1127 !

Loss : 0.4492215857822068
Starting epoch 1128 !

Loss : 0.43490687793209437
Starting epoch 1129 !

Loss : 0.43038704585866866
Starting epoch 1130 !

Loss : 0.4459332940327067
Starting epoch 1131 !

Loss : 0.4107447083101792
Starting epoch 1132 !

Loss : 0.45573259347758327
Starting epoch 1133 !

Loss : 0.45729536770963347
Starting epoch 1134 !

Loss : 0.46286948207689793
Starting epoch 1135 !

Loss : 0.43629357867500407
Starting epoch 1136 !

Loss : 0.4168001284607414
Starting epoch 1137 !

Loss : 0.4224006998397055
Starting epoch 1138 !

Loss : 0.40953458936846987
Starting epoch 1139 !

Loss : 0.4521467480935207
Starting epoch 1140 !

Loss : 0.4448705679502617
Starting epoch 1141 !

Loss : 0.43312249996629704
Starting epoch 1142 !

Loss : 0.4

Loss : 0.3618179008263309
Starting epoch 1289 !

Loss : 0.38286317084111327
Starting epoch 1290 !

Loss : 0.3812312063108496
Starting epoch 1291 !

Loss : 0.36128804270102055
Starting epoch 1292 !

Loss : 0.39040897330459284
Starting epoch 1293 !

Loss : 0.3780036634531151
Starting epoch 1294 !

Loss : 0.3716204382327138
Starting epoch 1295 !

Loss : 0.3246467129206982
Starting epoch 1296 !

Loss : 0.35869326695901194
Starting epoch 1297 !

Loss : 0.3787984922528267
Starting epoch 1298 !

Loss : 0.3821662509826576
Starting epoch 1299 !

Loss : 0.3593366388358226
Starting epoch 1300 !

Loss : 0.3822185557310273
Starting epoch 1301 !

Loss : 0.3685501353168974
Starting epoch 1302 !

Loss : 0.3760151812920765
Starting epoch 1303 !

Loss : 0.37007982593004396
Starting epoch 1304 !

Loss : 0.38248811124944365
Starting epoch 1305 !

Loss : 0.3870775050350598
Starting epoch 1306 !

Loss : 0.40723104055235987
Starting epoch 1307 !

Loss : 0.38750581352078184
Starting epoch 1308 !

Loss : 0.383

Loss : 0.3302072312454788
Starting epoch 1455 !

Loss : 0.31000284516081517
Starting epoch 1456 !

Loss : 0.3231293506362811
Starting epoch 1457 !

Loss : 0.3265129803192048
Starting epoch 1458 !

Loss : 0.3103939559893543
Starting epoch 1459 !

Loss : 0.320852150316952
Starting epoch 1460 !

Loss : 0.3144619105016293
Starting epoch 1461 !

Loss : 0.31353114758219036
Starting epoch 1462 !

Loss : 0.33511662047331026
Starting epoch 1463 !

Loss : 0.3316678420520153
Starting epoch 1464 !

Loss : 0.31946558231601907
Starting epoch 1465 !

Loss : 0.31132481556360414
Starting epoch 1466 !

Loss : 0.305146630726704
Starting epoch 1467 !

Loss : 0.3169097017218061
Starting epoch 1468 !

Loss : 0.3258873347421082
Starting epoch 1469 !

Loss : 0.31728439879457965
Starting epoch 1470 !

Loss : 0.34219101192999857
Starting epoch 1471 !

Loss : 0.31453570084912436
Starting epoch 1472 !

Loss : 0.3148994311994436
Starting epoch 1473 !

Loss : 0.32436025953617226
Starting epoch 1474 !

Loss : 0.3204

Loss : 0.2647425607607073
Starting epoch 1621 !

Loss : 0.2749956376394447
Starting epoch 1622 !

Loss : 0.2771313609618719
Starting epoch 1623 !

Loss : 0.27909107460659377
Starting epoch 1624 !

Loss : 0.2873751575342652
Starting epoch 1625 !

Loss : 0.2824402002452993
Starting epoch 1626 !

Loss : 0.24921690322914902
Starting epoch 1627 !

Loss : 0.2870144164379762
Starting epoch 1628 !

Loss : 0.2868137491398117
Starting epoch 1629 !

Loss : 0.2758449509131665
Starting epoch 1630 !

Loss : 0.28647203621815664
Starting epoch 1631 !

Loss : 0.2590323603731029
Starting epoch 1632 !

Loss : 0.2638340712505944
Starting epoch 1633 !

Loss : 0.2624354557622047
Starting epoch 1634 !

Loss : 0.28767341096587734
Starting epoch 1635 !

Loss : 0.3002013628985606
Starting epoch 1636 !

Loss : 0.2785132013920213
Starting epoch 1637 !

Loss : 0.2813253433627336
Starting epoch 1638 !

Loss : 0.25142285772630957
Starting epoch 1639 !

Loss : 0.26641306336842424
Starting epoch 1640 !

Loss : 0.27829

KeyboardInterrupt: 

In [35]:
pretext_network = Moco2(models.resnet50(num_classes=128))
pretext_network.load_state_dict(torch.load("./Untitled Folder/new_model_q_epoch_1544.pt"))


<All keys matched successfully>

In [36]:
from collections import OrderedDict


for param in pretext_network.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = 128
classifier = nn.Sequential(OrderedDict([
    ('fc1_a', nn.Linear(num_ftrs, 100)),
    ('added_relu1_a', nn.ReLU(inplace=True)),
    ('fc2_a', nn.Linear(100, 50)),
    ('added_relu2_a', nn.ReLU(inplace=True)),
    ('fc3_a', nn.Linear(50, 10))
]))
pretext_network.fc_lin = classifier



# pretext_network.fc_gal = nn.Linear(num_ftrs, 10).to(device)
pretext_network = pretext_network.to(device)


In [None]:
import timeit

batch_size = 128
train_dataloader_2 = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        num_workers=8,
        drop_last=True,
        shuffle=True,
)

def train_after(net_q, train_dataloader):
    
    criterion = nn.CrossEntropyLoss()
    parameters = list(filter(lambda p: p.requires_grad, net_q.parameters()))
#     assert len(parameters) == 2  # fc.weight, fc.bias
    optimizer = torch.optim.Adam(parameters, lr=3e-4)
#     optimizer = torch.optim.SGD(parameters, lr=0.1, momentum=0.9, weight_decay=1e-6)
    net_q.train()
    
    for epoch in range(1000):
        acc = 0
        total_loss = 0
        print(f"Starting epoch {epoch} !\n")
        for ((inputq, _), labels) in train_dataloader:
            
            time0 = timeit.default_timer()
            
            labels = labels.to(device)
            optimizer.zero_grad()
            x_q = net_q(inputq.to(device)).to(device)
            logits = net_q.fc_lin(x_q)
            
            time1 = timeit.default_timer()

            loss = criterion(logits, labels)
            pred = torch.argmax(logits,dim=-1)

            acc += (labels == pred).sum() / (batch_size * len(train_dataloader))
            total_loss+=loss
            

            loss.backward()
            optimizer.step()



        total_loss = total_loss /  len(train_dataloader)
        print("acc: "+str(acc))
        print("loss: "+str(total_loss))
        writer.add_scalar('Moco linear loss',
                    total_loss / len(train_dataloader),
                    epoch)

        if 1:
#             avg_loss = total_loss / (epoch + 1)
            torch.save(net_q.state_dict(),"./Untitled Folder/finished_model"+str(epoch)+".pt")



train_after(pretext_network, train_dataloader_2)

Starting epoch 0 !

acc: tensor(0.6118, device='cuda:0')
loss: tensor(18.4260, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 1 !

acc: tensor(0.6841, device='cuda:0')
loss: tensor(8.5645, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 2 !

acc: tensor(0.6744, device='cuda:0')
loss: tensor(5.5370, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 3 !

acc: tensor(0.6350, device='cuda:0')
loss: tensor(3.4422, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 4 !

acc: tensor(0.5574, device='cuda:0')
loss: tensor(2.0352, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 5 !

acc: tensor(0.5787, device='cuda:0')
loss: tensor(1.6543, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 6 !

acc: tensor(0.5864, device='cuda:0')
loss: tensor(1.5037, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 7 !

acc: tensor(0.6261, device='cuda:0')
loss: tensor(1.3682, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 8 !

acc: tensor(0.6483, device=

Starting epoch 69 !

acc: tensor(0.8344, device='cuda:0')
loss: tensor(0.5352, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 70 !

acc: tensor(0.8368, device='cuda:0')
loss: tensor(0.5372, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 71 !

acc: tensor(0.8279, device='cuda:0')
loss: tensor(0.5361, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 72 !

acc: tensor(0.8334, device='cuda:0')
loss: tensor(0.5434, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 73 !

acc: tensor(0.8298, device='cuda:0')
loss: tensor(0.5457, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 74 !

acc: tensor(0.8281, device='cuda:0')
loss: tensor(0.5550, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 75 !

acc: tensor(0.8365, device='cuda:0')
loss: tensor(0.5173, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 76 !

acc: tensor(0.8363, device='cuda:0')
loss: tensor(0.5231, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 77 !

acc: tensor(0.8325,

Starting epoch 137 !

acc: tensor(0.8480, device='cuda:0')
loss: tensor(0.4692, device='cuda:0', grad_fn=<DivBackward0>)
Starting epoch 138 !



[34m[1mwandb[0m: Currently logged in as: [33mguyshapira[0m (use `wandb login --relogin` to force relogin)
