In [2]:
torch.save(model,"resnet18_simclr_cifar100_centered.pt")

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR100
import numpy as np
from alexnet_cifar import *
import os, random


def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


class SimCLRTransform:
    def __init__(self, size):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)


class CIFAR100SimCLR(Dataset):
    def __init__(self, root='./data', train=True, transform=None):
        self.dataset = CIFAR100(root=root, train=train, download=True, transform=transform)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        return img
    
def nt_xent_loss(z_i, z_j, temperature):
    """
    Calculates the NT-Xent loss.
    z_i, z_j are the representations of two augmentations of the same image, 
    and should be normalized.
    """
    batch_size = z_i.size(0)

    z = torch.cat((z_i, z_j), dim=0)
    sim_matrix = torch.exp(torch.mm(z, z.T) / temperature)

    mask = torch.eye(batch_size, dtype=torch.bool).to(z.device)
    mask = mask.repeat(2, 2)
    sim_matrix = sim_matrix.masked_select(~mask).view(2 * batch_size, -1)

    positives = torch.exp(torch.sum(z_i * z_j, dim=-1) / temperature).repeat(2)
    negatives = sim_matrix.sum(dim=-1)

    loss = -torch.log(positives / negatives).mean()
    return loss

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
set_seed(0)

simclr_transform = SimCLRTransform(32) # cifar 
train_dataset = CIFAR100SimCLR(train=True, transform=simclr_transform)
val_dataset =  CIFAR100SimCLR(train=False, transform=simclr_transform)
train_loader = DataLoader(train_dataset, batch_size=1024*4, shuffle=True, num_workers=0)
val_loader =  DataLoader(val_dataset, batch_size=5000, shuffle=False, num_workers=0)

num_filters = [64, 192, 384, 256, 256, 4096, 4096]  # Example filter numbers for each layer
model = AlexNet_CIFAR_NoFC(num_filters).to(device)
projection_head = ProjectionHead(input_dim=num_filters[-1], hidden_dim=512, output_dim=128).to(device)

print(count_parameters(model))

def train(train_loader, model, projection_head, optimizer, scheduler, temperature=0.15, epochs=10):
    for epoch in range(epochs):
        for (images1, images2) in train_loader:
            model.train()
            # Concatenate the images from the two augmentations
            images = torch.cat([images1, images2], dim=0)
            images = images.to(device)

            optimizer.zero_grad()

            features = model(images)
            projections = projection_head(features)
            projections = F.normalize(projections, dim=1)

            loss = nt_xent_loss(projections[:len(images)//2], projections[len(images)//2:], temperature)
            
            loss.backward()
            optimizer.step()
            
        # Validation
        val_loss_list = []
        for (images1, images2) in val_loader:
            model.eval()
            # Concatenate the images from the two augmentations
            images = torch.cat([images1, images2], dim=0)
            images = images.to(device)

            features = model(images)
            projections = projection_head(features)
            projections = F.normalize(projections, dim=1)

            val_loss = nt_xent_loss(projections[:len(images)//2], projections[len(images)//2:], temperature)
            val_loss_list.append(val_loss.item())
        scheduler.step(np.mean(val_loss_list))
            
            # Print loss (or log it)
        
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {loss.item():.4f}, Val Loss: {np.mean(val_loss_list):.4f}')

optimizer = torch.optim.AdamW(list(model.parameters()) + list(projection_head.parameters()), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
train(train_loader, model, projection_head, optimizer, scheduler)

# torch.save(model,"resnet18_simclr_cifar100.pt")

cuda
Random seed set as 0
Files already downloaded and verified
Files already downloaded and verified
23231296
Epoch [1/10], Train Loss: 7.4259, Val Loss: 9.1901
Epoch [2/10], Train Loss: 7.3744, Val Loss: 9.1315
Epoch [3/10], Train Loss: 7.2775, Val Loss: 9.1407
Epoch [4/10], Train Loss: 7.3240, Val Loss: 9.0635
Epoch [5/10], Train Loss: 7.2312, Val Loss: 9.0311
Epoch [6/10], Train Loss: 7.2745, Val Loss: 9.0366
Epoch [7/10], Train Loss: 7.2476, Val Loss: 9.0227
Epoch [8/10], Train Loss: 7.3028, Val Loss: 9.0539
Epoch [9/10], Train Loss: 7.2232, Val Loss: 9.0059
Epoch [10/10], Train Loss: 7.2513, Val Loss: 9.0352


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR100
import numpy as np
from alexnet_cifar import *
import os, random

filter_list = np.array([64, 192, 384, 256, 256, 4096, 4096])
for i in list(np.geomspace(1, 32, 40)):
    num_filters = (filter_list / i).astype(int)
    print(num_filters)
    model = AlexNet_CIFAR_NoFC(num_filters)
    
    print("num parameters:", count_parameters(model))

[  64  192  384  256  256 4096 4096]
num parameters: 23231296
[  58  175  351  234  234 3747 3747]
num parameters: 19433498
[  53  160  321  214  214 3429 3429]
num parameters: 16271486
[  49  147  294  196  196 3137 3137]
num parameters: 13626805
[  44  134  269  179  179 2870 2870]
num parameters: 11398759
[  41  123  246  164  164 2626 2626]
num parameters: 9548498
[  37  112  225  150  150 2403 2403]
num parameters: 7993034
[  34  103  206  137  137 2198 2198]
num parameters: 6687038
[  31   94  188  125  125 2011 2011]
num parameters: 5592442
[  28   86  172  115  115 1840 1840]
num parameters: 4688797
[  26   78  157  105  105 1684 1684]
num parameters: 3923733
[  24   72  144   96   96 1541 1541]
num parameters: 3286811
[  22   66  132   88   88 1410 1410]
num parameters: 2753946
[  20   60  120   80   80 1290 1290]
num parameters: 2299980
[  18   55  110   73   73 1180 1180]
num parameters: 1923726
[  16   50  101   67   67 1080 1080]
num parameters: 1612687
[ 15  46  92  61  6

In [22]:
filter_list = np.array([64,64,128,256,512])

for i in list(np.geomspace(1, 32, 40)):
    num_filters = (filter_list / i).astype(int)
    print(num_filters)
    model = ResNet18_NoFC(BasicBlock, [2, 2, 2, 2], num_filters)
    print("num parameters:", count_parameters(model))

[ 64  64 128 256 512]
num parameters: 11176512
[ 58  58 117 234 468]
num parameters: 9336823
[ 53  53 107 214 428]
num parameters: 7810128
[ 49  49  98 196 392]
num parameters: 6554877
[ 44  44  89 179 358]
num parameters: 5461556
[ 41  41  82 164 328]
num parameters: 4591221
[ 37  37  75 150 300]
num parameters: 3839968
[ 34  34  68 137 274]
num parameters: 3202067
[ 31  31  62 125 251]
num parameters: 2681003
[ 28  28  57 115 230]
num parameters: 2255748
[ 26  26  52 105 210]
num parameters: 1882091
[ 24  24  48  96 192]
num parameters: 1576152
[ 22  22  44  88 176]
num parameters: 1324950
[ 20  20  40  80 161]
num parameters: 1105017
[ 18  18  36  73 147]
num parameters: 919438
[ 16  16  33  67 135]
num parameters: 774599
[ 15  15  30  61 123]
num parameters: 643675
[ 14  14  28  56 113]
num parameters: 544707
[ 12  12  25  51 103]
num parameters: 450683
[11 11 23 47 94]
num parameters: 377741
[10 10 21 43 86]
num parameters: 316302
[ 9  9 19 39 79]
num parameters: 264950
[ 9  9 18 

In [24]:
seed=0
torch.save(model,f"/home/mila/p/pingsheng.li/scratch/models/resnet18_simclr_cifar100_parameters{count_parameters(model)}_seed{seed}.pt")