In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
output_dim = 1024

In [4]:
model = torchvision.models.resnet50(pretrained=True)



In [5]:
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = torch.nn.Identity()
model.fc = torch.nn.Linear(in_features=2048, out_features=output_dim)

In [6]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentu

In [7]:
torch.save(model.state_dict(), 'pretrained_encoder.pth')

In [8]:
!ls

encoder-v2.ipynb       main.py			 pretrained_encoder.pth
example_submission.py  ModelStealingPub.pt	 __pycache__
finetune_encoder.py    pretrained_encoder.ipynb  SimCLRTransform.py


# Dataset

In [9]:
from example_submission import TaskDataset

In [10]:
data = torch.load("ModelStealingPub.pt", weights_only=False)

In [11]:
print(data.transform)

None


In [12]:
data.__getitem__(0)

(73838, <PIL.Image.Image image mode=RGB size=32x32>, '40019202')

In [13]:
mean = [0.2980, 0.2962, 0.2987]
std = [0.2886, 0.2875, 0.2889]

In [14]:
import torchvision.transforms as transforms

In [15]:
data_transforms = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize(mean, std)
])

In [16]:
data.transform = data_transforms

# Data loader

In [17]:
len(data)

13000

In [18]:
BATCH_SIZE = 64

In [19]:
data_loader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

# One try

In [20]:
for batch in data_loader:
    if batch % 10 == 0:
        print(batch)

RuntimeError: output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]

In [21]:
pic = data.__getitem__(0)[1]

In [22]:
type(pic)

torch.Tensor

In [23]:
pic.shape

torch.Size([3, 32, 32])

In [24]:
pic_vec = pic.unsqueeze(0)

In [25]:
pic_vec.shape

torch.Size([1, 3, 32, 32])

In [26]:
feature_vector = model(pic_vec)

In [27]:
feature_vector

tensor([[ 0.1278, -0.0790,  0.1375,  ..., -0.4608,  0.1628,  0.5917]],
       grad_fn=<AddmmBackward0>)

In [28]:
feature_vector.shape

torch.Size([1, 1024])

# Simple Framework for Contrastive Learning 

## Random transform

In [29]:
class SimCLRTransform:
    def __init__(self):
        self.base_transform = transforms.Compose([
            transforms.Lambda(lambda img: img.convert("RGB")),
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(size=32),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.GaussianBlur(kernel_size=3),
            transforms.ToTensor(), 
            transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        img1 = self.base_transform(img)
        img2 = self.base_transform(img)
        return img1, img2

## Data loader

In [30]:
data.transform = SimCLRTransform()
simclt_dataloader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

## SimCLR model

In [31]:
class SimCLR(torch.nn.Module):
    def __init__(self, base_model, feature_dim=1024, projection_dim=128):
        super().__init__()
        
        self.encoder = base_model
        
        # Projection Head
        self.projection_head = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        features = self.encoder(x)  # 1024-features vector
        projections = self.projection_head(features)
        return features, projections

In [32]:
encoder = model
simclr = SimCLR(encoder)

## Loss

In [33]:
def nt_xent_loss(z_i, z_j, temperature=0.5):
    batch_size = z_i.shape[0]
    z = torch.cat([z_i, z_j], dim=0) 

    sim = torch.nn.functional.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)

    labels = torch.arange(batch_size, device=z_i.device)
    labels = torch.cat([labels, labels], dim=0)

    sim /= temperature
    loss = torch.nn.functional.cross_entropy(sim, labels)
    return loss

## Optimizer

In [34]:
optimizer = torch.optim.Adam(simclr.parameters(), lr=3e-4, weight_decay=1e-4)

## Train loop

In [36]:
for epoch in range(10):
    print("====")
    for idx, (img1, img2), label in simclt_dataloader:
        _, z_i = simclr(img1)
        _, z_j = simclr(img2)
        
        loss = nt_xent_loss(z_i, z_j)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

====
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 
 →→→idx←←← 



KeyboardInterrupt



In [61]:
img1, img2 = data.__getitem__(0)[1]

In [62]:
img1, img2

(tensor([[[0.4757, 0.4893, 0.4757,  ..., 0.3806, 0.4350, 0.4757],
          [0.4350, 0.4621, 0.4485,  ..., 0.3806, 0.4350, 0.4757],
          [0.3534, 0.3670, 0.3942,  ..., 0.3806, 0.4214, 0.4621],
          ...,
          [0.3806, 0.4078, 0.4350,  ..., 0.5029, 0.4757, 0.4485],
          [0.4078, 0.4350, 0.4621,  ..., 0.5029, 0.4621, 0.4350],
          [0.4214, 0.4350, 0.4757,  ..., 0.5029, 0.4485, 0.4350]],
 
         [[0.2656, 0.2792, 0.2656,  ..., 0.1564, 0.1974, 0.2519],
          [0.2246, 0.2519, 0.2383,  ..., 0.1428, 0.1974, 0.2383],
          [0.1292, 0.1564, 0.1837,  ..., 0.1428, 0.1837, 0.2246],
          ...,
          [0.1564, 0.1701, 0.2110,  ..., 0.2792, 0.2519, 0.2246],
          [0.1837, 0.1974, 0.2383,  ..., 0.2792, 0.2383, 0.2246],
          [0.1974, 0.2246, 0.2519,  ..., 0.2928, 0.2383, 0.2246]],
 
         [[0.3099, 0.3235, 0.3099,  ..., 0.2013, 0.2285, 0.2692],
          [0.2692, 0.2963, 0.2828,  ..., 0.2013, 0.2285, 0.2692],
          [0.1878, 0.2149, 0.2285,  ...,

In [64]:
id_, (img1, img2), label = dataset[0]  # Get a single sample

In [65]:
print(img1.shape)  # Expecting torch.Size([3, 32, 32])
print(img2.shape)

torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
