In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import PIL 
from tqdm import tqdm
import matplotlib.pyplot as plt
# tsne and pca
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import untils
from MiniDeepTaxonNet import CobwebNN, CobwebNNTreeLayer, TestModel
import argparse
import os
import sys

In [7]:
# load mnist
download = True
dataset_class = datasets.MNIST
mnist_transform = [transforms.ToTensor()]
# mnist_transform.append(transforms.Normalize((0.1307,), (0.3081,)))

dataset_transform = transforms.Compose(mnist_transform)
mnist_train = dataset_class('data/MNIST', train=True, download=download, transform=dataset_transform)
mnist_test = dataset_class('data/MNIST', train=False, download=download, transform=dataset_transform)

mnist_train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)



In [24]:
strong_transform = transforms.Compose([
    # Random rotation between -30 and 30 degrees.
    transforms.RandomRotation(degrees=30),
    # Random affine: random translation, scaling, and shear.
    transforms.RandomAffine(
        degrees=90,                # no additional rotation here
        translate=(0.4, 0.4),     # translate up to 20% of image size
        scale=(0.4, 1.7),         # scale between 80% and 120%
        shear=50                 # shear by up to 20 degrees
    ),
    # radnom crop
    transforms.RandomCrop(28, padding=4),
    # Random horizontal flip.
    transforms.RandomHorizontalFlip(),
    # Randomly adjust brightness and contrast.
    # transforms.ColorJitter(brightness=0.5, contrast=0.5),
    transforms.ToTensor(),
])
no_transform = transforms.Compose([
    transforms.ToTensor(),
])

random_noise = transforms.Compose([
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.2),
])
def simclr_transform(x):
    # We apply 'train_transform' twice to get two distinct views
    # from the same original image.
    return no_transform(x), strong_transform(x)

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    # We'll override how __getitem__ works to produce two augmented views
    transform=None
)
# Because we want each sample to return a pair (x1, x2), we define a custom Dataset wrapper:
class SimCLRDatasetWrapper(torch.utils.data.Dataset):
    def __init__(self, dataset, transform_fn):
        self.dataset = dataset
        self.transform_fn = transform_fn
    
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        # Apply our simclr_transform to get x1, x2
        x1, x2 = self.transform_fn(img)
        # add random noise to x2
        # x2 = random_noise(x2)
        return x1, x2, label  # we can return the original label if needed (not used in unsupervised training)
    
    def __len__(self):
        return len(self.dataset)

train_dataset_simclr = SimCLRDatasetWrapper(train_dataset, simclr_transform)
mnist_simclr_train_loader = DataLoader(train_dataset_simclr, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)


In [25]:
class VGGLikeMNIST(nn.Module):
    def __init__(self, num_classes=10):
        super(VGGLikeMNIST, self).__init__()
        self.features = nn.Sequential(
            # Block 1: 28x28 -> 14x14
            nn.Conv2d(1, 64, kernel_size=3, padding=1),  # Input: 1 channel, Output: 64
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Block 2: 14x14 -> 7x7
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Block 3: 7x7 -> 3x3 (7//2 = 3 using floor division)
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        # After three pooling layers, the spatial size becomes roughly 3x3 (if using floor division).
        # Flatten the feature map and use a simple classifier.
        self.classifier = nn.Sequential(
            nn.Linear(256 * 3 * 3, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        # Flatten the feature map into a vector for the classifier
        x = x.view(x.size(0), -1)
        # x = self.classifier(x)
        return x

In [26]:
class IIC(nn.Module):
    def __init__(self, n_classes):
        super(IIC, self).__init__()
        self.n_classes = n_classes
        # CNN encoder for MNIST
        # self.encoder = nn.Sequential(
        #     nn.Conv2d(1, 32, 3, 1),
        #     nn.ReLU(),
        #     nn.Conv2d(32, 64, 3, 1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(64, 128, 3, 1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(128, 256, 3, 1),
        #     nn.ReLU()
        # ) # output shape: 256 x 3 x 3
        self.encoder = VGGLikeMNIST()
        # drop the projection head
        # projection head
        self.projection_head = nn.Sequential(
            nn.Linear(256 * 3 * 3, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, n_classes)
        )
        # # clustering head
        # self.cluster_head = nn.Sequential(
        #     nn.Linear(64, n_classes)
        # )

    def forward(self, x):
        x = self.encoder(x)
        # print(x.shape)
        # x = x.view(x.size(0), -1)
        x = self.projection_head(x)
        # x = F.softmax(self.cluster_head(x), dim=1) # batch x n_classes
        x = F.softmax(x, dim=1)
        return x


In [28]:
# train on MNIST

epochs = 10
n_classes = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = IIC(n_classes).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    if epoch % 2 == 0:
        # tsne on test set
        test_features = []
        test_labels = []
        for x, y in mnist_test_loader:
            x = x.to(device)
            z = model.encoder(x)
            z = z.view(z.size(0), -1)
            test_features.append(z.detach().cpu().numpy())
            test_labels.append(y.detach().cpu().numpy())
            break
        test_features = np.concatenate(test_features, axis=0)
        test_labels = np.concatenate(test_labels, axis=0)
        tsne = TSNE(n_components=2)
        tsne_features = tsne.fit_transform(test_features)
        plt.figure()
        plt.scatter(tsne_features[:, 0], tsne_features[:, 1], c=test_labels, cmap='tab10')
        plt.savefig(f'IIC.png')
        plt.close()

    for x, gx, _ in tqdm(mnist_simclr_train_loader):
        x = x.to(device)
        # display x
        # plt.imshow(x[0].squeeze().cpu().numpy())
        # random augmentation
        #### Your code here ####
        gx = gx.to(device)

        # display gx
        # plt.imshow(gx[0].squeeze().cpu().numpy())
        # break

        optimizer.zero_grad()

        # forward pass
        z = model(x)
        zt = model(gx)
        # print(z, zt)
        P = (z.unsqueeze(2) * zt.unsqueeze(1)).sum(dim=0) # n_classes x n_classes
        # print(P)
        P = ((P + P.t()) / 2) / P.sum()
        # print(P)
        # break
        eps = 1e-5
        P[(P < eps)] = eps
        Pi = P.sum(dim=1).view(-1, 1).expand(n_classes, n_classes)
        # clamp Pi
        Pi = torch.clamp(Pi, min=eps)
        # print(Pi)
        Pj = P.sum(dim=0).view(1, -1).expand(n_classes, n_classes)
        Pj = torch.clamp(Pj, min=eps)
        # Pj[(Pj < eps)] = eps
        # print(Pj)
        # break
        # loss = -(P * (torch.log(Pi) + torch.log(Pj) - torch.log(P))).sum()
        loss = -(P * (torch.log(P) - torch.log(Pi) - torch.log(Pj))).sum()
        # break
        # print(loss.item())
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')





100%|██████████| 469/469 [00:17<00:00, 26.27it/s]


Epoch 1/10, Loss: 5.437214056769335e-08


 17%|█▋        | 81/469 [00:03<00:15, 24.97it/s]


KeyboardInterrupt: 

In [13]:
# test the assignment probability
model.eval()
test_features = []
test_labels = []
for x, y in mnist_test_loader:
    x = x.to(device)
    z = model(x)
    y_pred = z.argmax(dim=1).detach().cpu().numpy()
    acc = (y_pred == y).float().mean()
    print(f'Accuracy: {acc.item()}')
    break

Accuracy: 0.09375
