In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import Dataset, DataLoader

from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, resnet34, resnet50
import torchvision.transforms as transforms
from torchsummary import summary

import numpy as np
import time

from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from medmnist import PathMNIST

from eqCLR.eq_resnet import EqResNet18



In [2]:

###################### PARAMS ##############################

BACKBONE = "resnet18"

BATCH_SIZE = 512
N_EPOCHS = 100 # 1000
N_CPU_WORKERS = 16
BASE_LR = 0.03         # important
WEIGHT_DECAY = 5e-4    # important
MOMENTUM = 0.9
PROJECTOR_HIDDEN_SIZE = 1024
PROJECTOR_OUTPUT_SIZE = 128
CROP_LOW_SCALE = 0.2
GRAYSCALE_PROB = 0.1   # important
PRINT_EVERY_EPOCHS = 5

MODEL_FILENAME = f"path_mnist-{BACKBONE}_wo_rotation-{np.random.randint(10000):04}.pt"


In [3]:

###################### DATA LOADER #########################

pmnist_train = PathMNIST(split='train', download=False, size=28, root='data/pathmnist/', transform=transforms.ToTensor())
pmnist_test = PathMNIST(split='test', download=False, size=28, root='data/pathmnist/', transform=transforms.ToTensor())

print("Data loaded.")

# additional rotation
class RandomRightAngleRotation:
    """Randomly rotate PIL image by 90, 180, or 270 degrees."""
    def __call__(self, x):
        angle = int(torch.randint(1, 4, ()).item()) * 90
        return x.rotate(angle)

transforms_ssl = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=32, scale=(CROP_LOW_SCALE, 1)),
        # RandomRightAngleRotation(), # additional rotation
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        ),
        transforms.RandomGrayscale(p=GRAYSCALE_PROB),
        transforms.ToTensor(), # NB: runtime faster when this line is last
    ]
)

class PairedTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return (self.transform(x), self.transform(x))


paired_ssl_transforms = PairedTransform(transforms_ssl)

pmnist_train_ssl = PathMNIST(split='train', download=False, size=28, root='data/pathmnist/', transform=paired_ssl_transforms)

pmnist_loader_ssl = DataLoader(
    pmnist_train_ssl,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_CPU_WORKERS,
    pin_memory=True,
)

###################### NETWORK ARCHITECTURE #########################

class ResNetwithProjector(nn.Module):
    def __init__(self, backbone_network):
        super().__init__()

        self.backbone = backbone_network(weights=None)
        self.backbone_output_dim = self.backbone.fc.in_features
        
        self.backbone.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.backbone.maxpool = nn.Identity()
        self.backbone.fc = nn.Identity()

        self.projector = nn.Sequential(
            nn.Linear(self.backbone_output_dim, PROJECTOR_HIDDEN_SIZE), 
            nn.ReLU(), 
            nn.Linear(PROJECTOR_HIDDEN_SIZE, PROJECTOR_OUTPUT_SIZE),
        )

    def forward(self, x):
        h = self.backbone(x)
        z = self.projector(h)
        return h, z


def infoNCE(features, temperature=0.5):
    x = F.normalize(features)
    cos_xx = x @ x.T / temperature
    cos_xx.fill_diagonal_(float("-inf"))
    
    batch_size = cos_xx.size(0) // 2
    targets = torch.arange(batch_size * 2, dtype=int, device=cos_xx.device)
    targets[:batch_size] += batch_size
    targets[batch_size:] -= batch_size

    return F.cross_entropy(cos_xx, targets)

backbones = {
   "resnet18": resnet18,    # backbone_output_dim = 512
   "resnet34": resnet34,    # backbone_output_dim = 512
   "resnet50": resnet50,    # backbone_output_dim = 2048
}


Data loaded.


# SimCLR

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_resnet18 = ResNetwithProjector(backbones[BACKBONE]).to(device)
model_resnet18_wo_rotation = ResNetwithProjector(backbones[BACKBONE]).to(device)

In [None]:
summary(model_resnet18, input_size=(3, 28, 28),  device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]           1,728
       BatchNorm2d-2           [-1, 64, 28, 28]             128
              ReLU-3           [-1, 64, 28, 28]               0
          Identity-4           [-1, 64, 28, 28]               0
            Conv2d-5           [-1, 64, 28, 28]          36,864
       BatchNorm2d-6           [-1, 64, 28, 28]             128
              ReLU-7           [-1, 64, 28, 28]               0
            Conv2d-8           [-1, 64, 28, 28]          36,864
       BatchNorm2d-9           [-1, 64, 28, 28]             128
             ReLU-10           [-1, 64, 28, 28]               0
       BasicBlock-11           [-1, 64, 28, 28]               0
           Conv2d-12           [-1, 64, 28, 28]          36,864
      BatchNorm2d-13           [-1, 64, 28, 28]             128
             ReLU-14           [-1, 64,

In [None]:
print(model_resnet18)

ResNetwithProjector(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): Identity()
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)

In [25]:
model_resnet18.load_state_dict(torch.load('path_mnist-resnet18-1146.pt', weights_only=True))
model_resnet18_wo_rotation.load_state_dict(torch.load('path_mnist-resnet18_wo_rotation-7711.pt', weights_only=True))

<All keys matched successfully>

# EqCLR

In [12]:
model_eq = EqResNet18(N=4).to(device)

summary(model_eq, input_size=(3, 28, 28),  device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
SingleBlockBasisExpansion-1             [-1, 4, 1, 49]               0
BlocksBasisExpansion-2                [-1, 3, 49]               0
            R2Conv-3          [-1, 256, 14, 14]              64
       BatchNorm3d-4        [-1, 64, 4, 14, 14]             128
    InnerBatchNorm-5          [-1, 256, 14, 14]               0
              ReLU-6          [-1, 256, 14, 14]               0
PointwiseMaxPool2D-7            [-1, 256, 7, 7]               0
SingleBlockBasisExpansion-8              [-1, 4, 4, 9]               0
SingleBlockBasisExpansion-9              [-1, 4, 4, 9]               0
SingleBlockBasisExpansion-10              [-1, 4, 4, 9]               0
SingleBlockBasisExpansion-11              [-1, 4, 4, 9]               0
SingleBlockBasisExpansion-12              [-1, 4, 4, 9]               0
SingleBlockBasisExpansion-13              [-1, 4, 4, 9] 

In [9]:
print(model_eq)

EqResNet18(
  (conv1): R2Conv([C4_on_R2[(None, 4)]: {irrep_0 (x3)}(3)], [C4_on_R2[(None, 4)]: {regular (x64)}(256)], kernel_size=7, stride=2, padding=3)
  (bn1): InnerBatchNorm([C4_on_R2[(None, 4)]: {regular (x64)}(256)], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=False, type=[C4_on_R2[(None, 4)]: {regular (x64)}(256)])
  (maxpool): PointwiseMaxPool2D()
  (layer1): SequentialModule(
    (0): EqBasicBlock(
      (conv1): R2Conv([C4_on_R2[(None, 4)]: {regular (x64)}(256)], [C4_on_R2[(None, 4)]: {regular (x64)}(256)], kernel_size=3, stride=1, padding=1, bias=False)
      (bn1): InnerBatchNorm([C4_on_R2[(None, 4)]: {regular (x64)}(256)], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=False, type=[C4_on_R2[(None, 4)]: {regular (x64)}(256)])
      (conv2): R2Conv([C4_on_R2[(None, 4)]: {regular (x64)}(256)], [C4_on_R2[(None, 4)]: {regular (x64)}(256)], kernel_size=3, stride=1, padding=1, bias=False)
      

In [26]:
model_eq.load_state_dict(torch.load('path_mnist-eqCLR_resnet18_wo_rotation-0319.pt', weights_only=True))

RuntimeError: Error(s) in loading state_dict for EqResNet18:
	Missing key(s) in state_dict: "conv1.expanded_bias", "conv1.filter", "layer1.0.conv1.filter", "layer1.0.conv2.filter", "layer1.1.conv1.filter", "layer1.1.conv2.filter", "layer2.0.conv1.filter", "layer2.0.conv2.filter", "layer2.0.downsample.filter", "layer2.1.conv1.filter", "layer2.1.conv2.filter", "layer3.0.conv1.filter", "layer3.0.conv2.filter", "layer3.0.downsample.filter", "layer3.1.conv1.filter", "layer3.1.conv2.filter", "layer4.0.conv1.filter", "layer4.0.conv2.filter", "layer4.0.downsample.filter", "layer4.1.conv1.filter", "layer4.1.conv2.filter". 
	size mismatch for fully_net.2.weight: copying a param with shape torch.Size([9, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for fully_net.2.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([128]).

# Evaluation

In [32]:
def dataset_to_X_y(dataset, model):
    X = []
    y = []
    Z = []

    for batch_idx, batch in enumerate(DataLoader(dataset, batch_size=1024)):
        images, labels = batch

        h, z = model(images.to(device))

        X.append(h.cpu().numpy())
        Z.append(z.cpu().numpy())
        y.append(labels.cpu().numpy().ravel())

    X = np.vstack(X)
    Z = np.vstack(Z)
    y = np.hstack(y)

    return X, y, Z

In [38]:
def eval_knn(X_train, y_train, X_test, y_test):
    eval_dict = {}

    for k in [1, 5, 10]:
        for metric in ["euclidean", "cosine"]:
            knn = KNeighborsClassifier(n_neighbors=k, metric=metric, n_jobs=-1)
            knn.fit(X_train, y_train)
            acc = knn.score(X_test, y_test)
            eval_dict[metric] = {k: acc}
            print(f"KNN (k={k}, metric={metric}): {acc*100:.2f}%")

    # Logistic Regression
    logreg = LogisticRegression(max_iter=1000, n_jobs=-1)
    logreg.fit(X_train, y_train)
    log_reg = logreg.score(X_test, y_test)
    eval_dict["logistic_regression"] = log_reg
    print(f"Logistic Regression: {log_reg*100:.2f}%")

    lin = LogisticRegression(penalty=None, solver="saga")
    lin.fit(X_train, y_train)
    lin_acc = lin.score(X_test, y_test)
    eval_dict["linear_accuracy"] = lin_acc
    print(f"Linear accuracy (sklearn): {lin_acc}", flush=True)

    return eval_dict

In [41]:
def lin_eval_rep(X_train, y_train, X_test, y_test, n_epochs=500, adam_lr=0.01):

    X_train = torch.tensor(X_train, device=device)
    X_test = torch.tensor(X_test, device=device)
    y_train = torch.tensor(y_train, device=device)
    y_test = torch.tensor(y_test, device=device)

    classifier = nn.Linear(X_train.shape[1], 10)
    classifier.to(device)
    classifier.train()

    optimizer = Adam(classifier.parameters(), lr=adam_lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    for epoch in range(n_epochs):
        N = len(X_train)
        perm = torch.randperm(N)
        perm = perm[:N - (N % 1000)]              # drop remainder
        batches = perm.view(-1, 1000)        
        for idx in batches:
            optimizer.zero_grad()
            logits = classifier(X_train[idx])
            loss = F.cross_entropy(logits, y_train[idx])
            loss.backward()
            optimizer.step()
        scheduler.step()

    classifier.eval()
    with torch.no_grad():
        yhat = classifier(X_test)

    acc = (yhat.argmax(axis=1) == y_test).cpu().numpy().mean()
    print(f"Linear accuracy (Adam on precomputed representations): {acc}", flush=True)

    return acc

In [54]:
def lin_eval_aug(loader_classifier, model, n_classes, n_epochs=100, adam_lr=0.01, adam_wd=5e-6):
    classifier = nn.Linear(model.backbone_output_dim, n_classes)
    for param in model.backbone.parameters():
        param.requires_grad = False

    optimizer = Adam(classifier.parameters(), lr=adam_lr, weight_decay=adam_wd)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    classifier.to(device)
    classifier.train()
    training_start_time = time.time()

    for epoch in range(n_epochs):
        epoch_loss = 0.0
        start_time = time.time()

        for batch_idx, batch in enumerate(loader_classifier):
            view, y = batch

            optimizer.zero_grad()

            h, _ = model(view.to(device))
            logits = classifier(h)
            loss = F.cross_entropy(logits, y.to(device).squeeze().long())
            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

        end_time = time.time()
        if (epoch + 1) % PRINT_EVERY_EPOCHS == 0:
            print(
                f"Epoch {epoch + 1}, "
                f"average loss {epoch_loss / len(loader_classifier):.4f}, "
                f"{end_time - start_time:.1f} s",
                flush=True
            )

    scheduler.step()

    training_end_time = time.time()
    hours = (training_end_time - training_start_time) / 60 // 60
    minutes = (training_end_time - training_start_time) / 60 % 60
    print(
        f"Total classifier training length for {n_epochs} epochs: {hours:.0f}h {minutes:.0f}min",
        flush=True
    )

    classifier.eval()
    with torch.no_grad():
        yhat = []
        y = []

        for batch_idx, batch in enumerate(DataLoader(pmnist_test, batch_size=1024)):
            images, labels = batch

            h, _ = model(images.to(device))
            logits = classifier(h)

            yhat.append(logits.cpu().numpy())
            y.append(labels)

        yhat = np.vstack(yhat)
        y = np.hstack(y)

    acc = (yhat.argmax(axis=1) == y).mean()
    print(f"Linear accuracy (trained with augmentations): {acc}", flush=True)

    return acc

## SimCLR

### With rotation (default)

In [None]:
with torch.no_grad():
    X_train, y_train, Z_train = dataset_to_X_y(pmnist_train, model_resnet18)
    X_test, y_test, Z_test = dataset_to_X_y(pmnist_test, model_resnet18)

In [None]:
resnet_18_eval = eval_knn(X_train, y_train, X_test, y_test)

KNN (k=1, metric=euclidean): 84.53%
KNN (k=1, metric=cosine): 84.48%
KNN (k=5, metric=euclidean): 87.06%
KNN (k=5, metric=cosine): 86.96%
KNN (k=10, metric=euclidean): 87.58%
KNN (k=10, metric=cosine): 87.30%
Logistic Regression: 90.56%
Linear accuracy (sklearn): 0.9036211699164345




In [42]:
evaaal = lin_eval_rep(X_train, y_train, X_test, y_test)

Linear accuracy (Adam on precomputed representations): 0.8998607242339833


In [45]:
transforms_classifier = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=32, scale=(CROP_LOW_SCALE, 1)),
        transforms.RandomHorizontalFlip(),
        RandomRightAngleRotation(), # additional rotation
        transforms.ToTensor(),
    ]
)

pmnist_train_classifier = PathMNIST(split='train', download=False, size=28, root='data/pathmnist/', transform=transforms_classifier)


pmnist_loader_classifier = DataLoader(
    pmnist_train_classifier,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_CPU_WORKERS,
)

In [55]:
evaaaaaaal = eval_lin_aug = lin_eval_aug(pmnist_loader_classifier, model_resnet18, n_classes=9)

KeyboardInterrupt: 