In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader

In [2]:
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any other cell visible here — confirm
# that the model and batches are moved to it somewhere later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Width of the encoder's output embedding; used as `out_features` of the replacement `fc` layer.
output_dim = 1024

In [4]:
# Load ResNet-50 with ImageNet weights through the `weights=` API. `pretrained=True` is
# deprecated (torchvision >= 0.13) and resolves to this same IMAGENET1K_V1 checkpoint,
# so the loaded parameters are unchanged.
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)



In [5]:
# Adapt the stem to small (32x32) inputs: a 3x3 stride-1 first convolution and no
# max-pooling, so the stem performs no spatial downsampling. The new conv1 is freshly
# initialised and replaces the pretrained one.
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = torch.nn.Identity()
# Replace the classifier with an embedding layer. The input width is read from the
# existing head instead of hard-coding 2048, so the cell stays valid for other ResNet
# depths and is still idempotent when re-run.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=output_dim)

In [6]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentu

In [34]:
# Persist the encoder parameters (state_dict only) to a file in the working directory.
# NOTE(review): this cell's execution count (34) is out of order with the cells around
# it, so what gets saved depends on the kernel state at that moment — confirm the
# placement is intended.
torch.save(model.state_dict(), 'pretrained_encoder.pth')

In [35]:
!ls

example_submission.py  pretrained_encoder.ipynb
ModelStealingPub.pt    pretrained_encoder.pth


# Dataset

In [7]:
from example_submission import TaskDataset

In [8]:
# Load the pickled dataset object from disk.
# NOTE(review): presumably this unpickles a `TaskDataset`, which is why that class is
# imported in the previous cell — confirm.
# SECURITY: `weights_only=False` uses the full pickle machinery and can execute
# arbitrary code while loading; only open this file if it comes from a trusted source.
data = torch.load("ModelStealingPub.pt", weights_only=False)

In [9]:
print(data.transform)

None


In [21]:
# Peek at the first sample; plain indexing is the idiomatic spelling of `__getitem__`.
data[0]

(73838, <PIL.Image.Image image mode=RGB size=32x32>, '40019202')

In [10]:
# Per-channel (R, G, B) normalisation constants passed to `transforms.Normalize`.
# NOTE(review): presumably computed from this dataset — the derivation is not shown
# in the notebook; confirm.
mean = [0.2980, 0.2962, 0.2987]
std = [0.2886, 0.2875, 0.2889]

In [11]:
import torchvision.transforms as transforms

In [12]:
# Deterministic preprocessing: convert the PIL image to a tensor, then standardise
# each channel with the constants defined above.
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

In [13]:
# Attach the preprocessing pipeline to the dataset; this mutates `data` in place.
# NOTE(review): presumably the dataset's `__getitem__` applies `transform` to the
# image (the class lives in example_submission) — confirm.
data.transform = data_transforms

# Data loader

In [14]:
len(data)

13000

In [15]:
# Mini-batch size used when building the data loaders.
BATCH_SIZE = 64
# `dataset` is a second name for the same object as `data` (no copy is made), so any
# later change to `data.transform` also affects loaders built from `dataset`.
dataset = data

In [16]:
# Shuffled mini-batch loader over the dataset (uses the `DataLoader` name imported
# in the first cell).
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Sanity check: a single forward pass

In [18]:
# Take the image (element 1 of the sample tuple) from the first sample, using
# ordinary indexing rather than calling `__getitem__` directly.
pic = data[0][1]

In [19]:
type(pic)

torch.Tensor

In [21]:
pic.shape

torch.Size([3, 32, 32])

In [23]:
# Prepend a batch dimension of size 1 so the single image can be fed to the model.
pic_vec = pic.unsqueeze(0)

In [24]:
pic_vec.shape

torch.Size([1, 3, 32, 32])

In [26]:
# Forward one image through the encoder as a smoke test.
# NOTE(review): no `model.eval()` / `torch.no_grad()` is visible before this call, so
# it presumably runs in training mode (updating BatchNorm running statistics) and it
# builds an autograd graph — confirm this is intended for a sanity check.
feature_vector = model(pic_vec)

In [27]:
feature_vector

tensor([[ 0.2537,  0.2830, -0.2327,  ...,  0.0889, -0.3282,  0.0436]],
       grad_fn=<AddmmBackward0>)

In [28]:
feature_vector.shape

torch.Size([1, 1024])

# SimCLR: A Simple Framework for Contrastive Learning

## Random transform

In [30]:
class SimCLRTransform:
    """Produce two independently augmented views of one image for SimCLR training.

    Calling an instance applies the same stochastic augmentation pipeline twice
    to the input image and returns both results, giving a positive pair for the
    contrastive loss.
    """

    def __init__(self):
        """Build the augmentation pipeline (flip, crop, colour jitter, blur, grayscale)."""
        self.base_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(size=32),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.GaussianBlur(kernel_size=3),
            # torchvision spells this `RandomGrayscale` and it takes a probability `p`,
            # not a kernel size; the previous `RandomGreyscale(kernel_size=3)` raised
            # AttributeError on construction. p=0.2 is the value used by SimCLR.
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def __call__(self, img):
        """Return a tuple of two random augmentations of ``img``."""
        # The pipeline is stochastic, so two calls yield two different views.
        img1 = self.base_transform(img)
        img2 = self.base_transform(img)
        return img1, img2

## Data loader

In [32]:
# Assign an *instance* of the transform, not the class itself. With the bare class,
# applying `transform(img)` would try to construct `SimCLRTransform(img)` and raise a
# TypeError instead of returning the two augmented views.
data.transform = SimCLRTransform()
# `dataset` refers to the same object as `data`, so this loader yields augmented pairs.
simclt_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

## SimCLR model

In [33]:
class SimCLR(torch.nn.Module):
    """Encoder followed by a two-layer MLP projection head.

    Args:
        base_model: module that maps an input batch to ``feature_dim``-wide vectors.
        feature_dim: width of the encoder output fed into the projection head.
        projection_dim: width of the projection head's output.
    """

    def __init__(self, base_model, feature_dim=1024, projection_dim=128):
        super().__init__()

        self.encoder = base_model

        # Projection head: Linear -> ReLU -> Linear with a 512-unit hidden layer.
        hidden_dim = 512
        head_layers = [
            torch.nn.Linear(feature_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, projection_dim),
        ]
        self.projection_head = torch.nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``(features, projections)`` for the input batch ``x``."""
        embedding = self.encoder(x)
        return embedding, self.projection_head(embedding)

In [36]:
# Wrap the adapted ResNet in the SimCLR model. `feature_dim` is tied to `output_dim`
# so the projection head always matches the width of the encoder's `fc` layer rather
# than relying on the constructor default happening to equal it.
encoder = model
simclr = SimCLR(encoder, feature_dim=output_dim)