In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [10]:
def NT_xentloss_(z1, z2, temperature=0.5): 
    N, Z = z1.shape 
    device = z1.device 
    representations = torch.cat([z1, z2], dim=0)
    similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)

    l_pos = torch.diag(similarity_matrix, N)
    r_pos = torch.diag(similarity_matrix, -N)
    positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

    diag = torch.eye(2*N, dtype=torch.bool, device=device)
    diag[N:,:N] = diag[:N,N:] = diag[:N,:N]
    negatives = similarity_matrix[~diag].view(2*N, -1)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample
    loss = F.cross_entropy(logits, labels, reduction='sum')

    return loss / (2 * N)

class NT_xentloss(nn.Module):
    def __init__(self, temperature=0.5):
        super(NT_xentloss, self).__init__()
        self.temperature = temperature
        
    def forward(self, z1, z2):
        N, Z = z1.shape 
        device = z1.device 
        representations = torch.cat([z1, z2], dim=0)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)

        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

        diag = torch.eye(2*N, dtype=torch.bool, device=device)
        diag[N:,:N] = diag[:N,N:] = diag[:N,:N]
        negatives = similarity_matrix[~diag].view(2*N, -1)

        logits = torch.cat([positives, negatives], dim=1) / temperature
        labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample
        loss = F.cross_entropy(logits, labels, reduction='sum')
        
        return loss / (2 * N)

In [65]:
from model.SimCLR import simclr, create_backbone
import random

In [95]:
model = create_backbone(name='res18', num_classes=10)
classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True)
for name, value in model.named_parameters():
    if not name.startswith('linear') :
        value.requires_grad = False
pretrained_model = torch.load('../checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep1_rn100.ckpt', map_location='cpu')
model.load_state_dict({k[9:]:v for k, v in pretrained_model['model'].items() if k.startswith('backbone.')}, strict=False)

del pretrained_model
# model.add_module("Linear", classifier)
for name, value in model.named_parameters():
    print(name, value)

conv1.weight Parameter containing:
tensor([[[[-0.0183, -0.0765,  0.2014],
          [-0.0594, -0.0160, -0.0777],
          [-0.1222, -0.0284,  0.1164]],

         [[ 0.3198,  0.2831,  0.3139],
          [ 0.3030, -0.0359,  0.2107],
          [ 0.3482, -0.0047,  0.4345]],

         [[-0.0811, -0.0312, -0.0877],
          [-0.4400, -0.6757, -0.4297],
          [ 0.0313, -0.4326,  0.0159]]],


        [[[-0.0643,  0.2237,  0.4458],
          [-0.3193, -0.0773,  0.0561],
          [-0.4273, -0.0085,  0.1814]],

         [[-0.0256, -0.1148,  0.1742],
          [-0.4455, -0.1189,  0.2433],
          [-0.4229, -0.2738, -0.1874]],

         [[-0.0910,  0.0338,  0.2989],
          [-0.1994,  0.0990,  0.2593],
          [-0.0121,  0.1451, -0.0939]]],


        [[[-0.3403, -0.4584, -0.3110],
          [ 0.1626,  0.2688,  0.1822],
          [ 0.3948,  0.3835,  0.1284]],

         [[-0.4745, -0.1487, -0.1274],
          [-0.1165, -0.2197, -0.0441],
          [ 0.1622,  0.2059, -0.1887]],

         

In [93]:
def get_freezed_parameters(module):
    """
    Returns names of freezed parameters of the given module.
    """
    
    freezed_parameters = []
    for name, parameter in module.named_parameters():
        if not parameter.requires_grad:
            freezed_parameters.append(name)
            
    return freezed_parameters

get_freezed_parameters(model)
for i in filter(lambda p: p.requires_grad, model.parameters()):
    print(i)

Parameter containing:
tensor([[-0.0187, -0.0211, -0.0236,  ..., -0.0326,  0.0137,  0.0262],
        [-0.0381, -0.0086,  0.0255,  ..., -0.0020,  0.0384, -0.0152],
        [ 0.0052,  0.0015, -0.0284,  ...,  0.0357,  0.0110, -0.0346],
        ...,
        [-0.0243,  0.0382, -0.0128,  ...,  0.0228,  0.0274, -0.0390],
        [-0.0257,  0.0127,  0.0342,  ...,  0.0194,  0.0440, -0.0125],
        [-0.0125,  0.0293, -0.0360,  ..., -0.0290,  0.0309,  0.0352]],
       requires_grad=True)
Parameter containing:
tensor([-0.0302,  0.0121,  0.0225, -0.0034,  0.0179, -0.0157, -0.0225, -0.0358,
         0.0366, -0.0426], requires_grad=True)


In [94]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
optimizer.zero_grad()
output = model(torch.rand((1,3,32,32)))
loss = F.cross_entropy(output, torch.tensor([1]))
loss.backward()
optimizer.step()

In [99]:
'abcdfgh'.endswith('dfgh')

True