In [1]:
import torch
import torchvision
import numpy as np
from simclr import SimCLR
from simclr.modules import get_resnet, NT_Xent
from simclr.modules.transformations import TransformsSimCLR

In [2]:
device= torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

#### Setup Hyperparameters, can change for tuning

In [3]:
args = {
    'batch_size': 64,
    'epochs': 100,
    'image_size': 224,
    'resnet': 'resnet18',
    'temperature': 0.5,
    'weight_decay': 5e-5,
    'learning_rate': 1e-5,
    'dataset_dir': '../../data/'
}

#### Training Data Loading/Processing

In [6]:
# args['dataset'] = 'STL10'
# dataset = torchvision.datasets.STL10(args['dataset_dir'], split='unlabeled', download=True, transform=TransformsSimCLR(size=args['image_size']))
args['dataset'] = 'CIFAR100'
dataset = torchvision.datasets.CIFAR100(args['dataset_dir'], download=True, transform=TransformsSimCLR(size=args['image_size']))

Files already downloaded and verified


In [17]:
training_loader = torch.utils.data.DataLoader(dataset, batch_size=args['batch_size'], shuffle=True)

#### Model Setup

In [7]:
encoder = get_resnet(args['resnet'])
model = SimCLR(encoder, 64, encoder.fc.in_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), args['learning_rate'])



In [8]:
loss_func = NT_Xent(args['batch_size'], args['temperature'], world_size=1)

#### Training

In [16]:
def train(model, optimizer, loss_func, training_loader):
    total_loss = 0
    for step, ((xi, xj), _) in enumerate(training_loader):
        optimizer.zero_grad()
        #need to make the tensor the right type or it will fail
        xi = xi.to(device)
        xj = xj.to(device)
        #xi, xj are the two correlated augmented examples
        #first two are the represenatations, and try to maximize z agreements
        _, _, zi, zj = model(xi, xj)
        if zi.size()[0] == 64:
            loss = loss_func(zi, zj)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if step % 130 == 0:
                print(f'Step {step} Loss: {loss.item()}')
    return total_loss

In [9]:
for epoch in range(10):
    epoch_loss = train(model, optimizer, loss_func, training_loader)
    print(f"Epoch {epoch} Loss: {epoch_loss}")

Step 0 Loss: 4.815101623535156
Step 130 Loss: 4.661430358886719
Step 260 Loss: 4.452215671539307
Step 390 Loss: 4.550997734069824
Step 520 Loss: 4.545881271362305
Step 650 Loss: 4.3085126876831055
Step 780 Loss: 4.443071365356445
Epoch 0 Loss: 3480.9206972122192
Step 0 Loss: 4.418758392333984
Step 130 Loss: 4.337266445159912
Step 260 Loss: 4.290072917938232
Step 390 Loss: 4.3134660720825195
Step 520 Loss: 4.418010711669922
Step 650 Loss: 4.3648247718811035
Step 780 Loss: 4.3372931480407715
Epoch 1 Loss: 3355.4564394950867
Step 0 Loss: 4.23141622543335
Step 130 Loss: 4.1005425453186035
Step 260 Loss: 4.030501365661621
Step 390 Loss: 4.066650867462158
Step 520 Loss: 4.0562920570373535
Step 650 Loss: 4.0688371658325195
Step 780 Loss: 4.1395487785339355
Epoch 2 Loss: 3290.510726213455
Step 0 Loss: 4.162351608276367
Step 130 Loss: 4.160772800445557
Step 260 Loss: 4.085745811462402
Step 390 Loss: 4.122556209564209
Step 520 Loss: 4.175682067871094
Step 650 Loss: 4.183428764343262
Step 780 Los

#### Save the model weights

In [15]:
torch.save(model.state_dict(), 'simclr.tar')

In [8]:
model.load_state_dict(torch.load('simclr.tar'))
model.to(device)

SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

#### Pretrained model from https://github.com/Spijkervet/SimCLR/releases/download/1.2/ (much better given 100 epochs of training)

In [4]:
resnet = get_resnet('resnet50', pretrained=False)
model2 = SimCLR(resnet, 64, resnet.fc.in_features)
model2.load_state_dict(torch.load('checkpoint_100.tar'), strict=False)
model2.to(device)



SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

#### Data setup for evaluation step

In [32]:
training_dataset = torchvision.datasets.CIFAR100(args['dataset_dir'], download=True, transform=TransformsSimCLR(size=args['image_size']).test_transform, train=True)
N = len(training_dataset)
dataset_pct = 0.5 #change this for percent of data to train on for efficient training metric
num_train_samples = int(N * dataset_pct)
dataset_indices = np.random.choice(N, num_train_samples, replace=False)
dataset_subset = torch.utils.data.Subset(training_dataset, dataset_indices)

# split into train/val
N_subset = len(dataset_subset)
V = int(num_train_samples * 0.2) # 20% validation
dataset_train, dataset_val = torch.utils.data.random_split(dataset_subset, [N_subset - V, V])
test_dataset = torchvision.datasets.CIFAR100(args['dataset_dir'], download=True, transform=TransformsSimCLR(size=args['image_size']).test_transform, train=False)

Files already downloaded and verified
Files already downloaded and verified


In [33]:
training_loader= torch.utils.data.DataLoader(dataset_train, batch_size=args['batch_size'], shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset_val, batch_size=args['batch_size'], shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)

#### Setup Logistic Regression for evaluation (we may do something differnt to better match the CPC framework)

In [34]:
class LogisticModel(torch.nn.Module):
    def __init__(self, features, classes):
        super(LogisticModel, self).__init__()
        self.model = torch.nn.Linear(features, classes)
    def forward(self, x):
        return self.model(x)

In [35]:
# change the number of features to 10 or 100 depending on CIFAR
logistic_model = LogisticModel(model2.n_features, 100).to(device) #model is the simclr as defined above
optimizer= torch.optim.Adam(logistic_model.parameters(), lr=1e-4)
loss_func = torch.nn.CrossEntropyLoss()

#### Encoder to label the data with their represenations

In [36]:
def encoder(loader, simclr):
    features = []
    labels = []

    for _, (x, y) in enumerate(loader):
        x = x.to(device)
        with torch.no_grad():
            xi, _, _, _ = simclr(x, x)
        
        xi = xi.detach()
        features.extend(xi.cpu().detach().numpy())
        labels.extend(y.numpy())
    features = np.array(features)
    labels = np.array(labels)
    return features, labels

def populate_labels(simclr, training_loader, validation_loader, test_loader):
    X_train, y_train = encoder(training_loader, simclr)
    X_val, y_val = encoder(validation_loader, simclr)
    X_test, y_test = encoder(test_loader, simclr)
    return X_train, y_train, X_val, y_val, X_test, y_test

In [37]:
def loaders(X_train, y_train, X_val, y_val, X_test, y_test, batch):
    training = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
    validation = torch.utils.data.TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
    testing = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))

    training_loader_final = torch.utils.data.DataLoader(training, batch_size=batch, shuffle=False)
    validation_loader_final = torch.utils.data.DataLoader(validation, batch_size=batch, shuffle=False)
    test_loader_final = torch.utils.data.DataLoader(testing, batch_size=batch, shuffle=False)
    return training_loader_final, validation_loader_final, test_loader_final

In [38]:
(X_train, y_train, X_val, y_val, X_test, y_test) = populate_labels(model2, training_loader, validation_loader, test_loader)
training_loader_final, validation_loader_final, test_loader_final = loaders(X_train, y_train, X_val, y_val, X_test, y_test, args['batch_size'])

#### Training on patched data

In [39]:
def train(logistic, loss_func, optimizer, data):
    overall_accuracy = 0
    total_loss = 0
    for _, (x, y) in enumerate(data):
        optimizer.zero_grad()
        x =x.to(device)
        y =y.to(device)

        out = logistic(x)
        loss = loss_func(out, y)

        guess = out.argmax(1)
        accuracy = (guess==y).sum().item() /  y.size(0)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        overall_accuracy += accuracy

    return total_loss, overall_accuracy


In [40]:
def validation(logistic, data):
    overall_accuracy = 0
    with torch.no_grad():
        for _, (x, y) in enumerate(data):
            x =x.to(device)
            y =y.to(device)

            out = logistic(x)
            guess = out.argmax(1)
            accuracy = (guess==y).sum().item() /  y.size(0)
            overall_accuracy += accuracy

    return overall_accuracy

In [41]:
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(logistic_model.parameters(), lr=1e-3)

In [42]:
random_seed = 15009
torch.manual_seed(random_seed)
np.random.seed(random_seed)
for epoch in range(50):
    epoch_loss, epoch_accuracy = train(logistic_model, loss_func, optimizer, training_loader_final)
    print(f"Epoch {epoch} Loss: {epoch_loss/len(training_loader_final)} Accuracy: {epoch_accuracy/len(training_loader_final)}")
    validation_accuracy = validation(logistic_model, validation_loader_final)
    print(f"Epoch {epoch} Validation Accuracy: {validation_accuracy/len(validation_loader_final)}")

Epoch 0 Loss: 3.733852380380844 Accuracy: 0.16164137380191693
Epoch 0 Validation Accuracy: 0.18947784810126583
Epoch 1 Loss: 3.3106722626061487 Accuracy: 0.2248402555910543
Epoch 1 Validation Accuracy: 0.2142009493670886
Epoch 2 Loss: 3.1472154600551714 Accuracy: 0.25324480830670926
Epoch 2 Validation Accuracy: 0.22626582278481014
Epoch 3 Loss: 3.0287524366531127 Accuracy: 0.2744608626198083
Epoch 3 Validation Accuracy: 0.23516613924050633
Epoch 4 Loss: 2.934362782457004 Accuracy: 0.29018570287539935
Epoch 4 Validation Accuracy: 0.24347310126582278
Epoch 5 Loss: 2.855410454753108 Accuracy: 0.30501198083067094
Epoch 5 Validation Accuracy: 0.25197784810126583
Epoch 6 Loss: 2.787312414699469 Accuracy: 0.3176916932907348
Epoch 6 Validation Accuracy: 0.2614715189873418
Epoch 7 Loss: 2.727297328912412 Accuracy: 0.32992212460063897
Epoch 7 Validation Accuracy: 0.26542721518987344
Epoch 8 Loss: 2.673549623154223 Accuracy: 0.3404552715654952
Epoch 8 Validation Accuracy: 0.270371835443038
Epoch 

#### Final Testing

In [22]:
def final_eval(logistic, data):
    overall_accuracy = 0
    total_loss = 0
    logistic.eval()
    for _, (x,y) in enumerate(data):
        x= x.to(device)
        y = y.to(device)

        out = logistic(x)
        loss = loss_func(out, y)
        guess = out.argmax(1)
        accuracy = (guess==y).sum().item() /  y.size(0)

        total_loss += loss.item()
        overall_accuracy += accuracy
    return total_loss, overall_accuracy

In [23]:
random_seed = 15009
torch.manual_seed(random_seed)
np.random.seed(random_seed)
loss, accuracy = final_eval(logistic_model, test_loader_final)
print( f'Loss: {loss/len(test_loader_final)}, Accuracy: {accuracy/len(test_loader_final)}')

Loss: 2.245242929762336, Accuracy: 0.4211783439490446
