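# Contrastive Learning using FAISS and PyTorch on the CIFAR-10 Dataset
# (Note: despite the title, no contrastive loss is used below. The encoder is pretrained with a supervised cross-entropy loss, and FAISS supplies nearest-neighbour features to a cross-entropy-trained classifier.)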

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import faiss

from torchvision.models import resnet18

class Encoder(nn.Module):
    """ResNet-18 trunk followed by a linear projection to an embedding.

    Args:
        embedding_dim: size of the produced embedding.
        pretrained: forwarded unchanged to ``torchvision.models.resnet18``.

    ``forward(x)`` returns ``(embedding, feats)``: ``feats`` is the flattened
    output of the trunk (everything except ResNet's final ``fc``), and
    ``embedding`` is the new linear layer applied to ``feats``.
    """

    def __init__(self, embedding_dim=128, pretrained=False):
        super().__init__()
        resnet = resnet18(pretrained=pretrained)
        # Keep every child module except the ImageNet classification layer.
        trunk_layers = list(resnet.children())
        self.features = nn.Sequential(*trunk_layers[:-1])
        # Fresh projection sized from the discarded classifier's input width.
        self.fc = nn.Linear(resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        pooled = self.features(x)                  # (B, 512, 1, 1) for ResNet-18
        flat = torch.flatten(pooled, start_dim=1)  # (B, 512)
        return self.fc(flat), flat


class SimpleEncoderModel(pl.LightningModule):
    """Supervised pretraining wrapper: an ``Encoder`` plus a linear classifier.

    The whole network is trained with cross-entropy on the class labels.
    Loss and accuracy are logged as ``pretrain_loss`` / ``pretrain_acc``.

    Args:
        embedding_dim: encoder embedding size (classifier input width).
        num_classes: number of output classes.
        lr: Adam learning rate.
    """

    def __init__(self, embedding_dim=128, num_classes=10, lr=1e-3):
        super().__init__()
        self.encoder = Encoder(embedding_dim=embedding_dim, pretrained=False)
        self.classifier = nn.Linear(embedding_dim, num_classes)
        self.save_hyperparameters()

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self.encoder(x)[0])

    def training_step(self, batch, batch_idx):
        images, targets = batch
        logits = self(images)
        loss = F.cross_entropy(logits, targets)
        correct = logits.argmax(dim=1).eq(targets)
        self.log('pretrain_loss', loss, prog_bar=True)
        self.log('pretrain_acc', correct.float().mean(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class PretrainDataModule(pl.LightningDataModule):
    """Serves a fixed random subset of the CIFAR-10 training split.

    Args:
        batch_size: training batch size.
        pretrain_percentage: fraction of the 50k training images kept.
        data_dir: CIFAR-10 root directory (downloaded by ``prepare_data``).
    """

    def __init__(self, batch_size=32, pretrain_percentage=0.2, data_dir='../data/raw'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.pretrain_percentage = pretrain_percentage

    def prepare_data(self):
        # Download only; the dataset object itself is built in setup().
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        full_train = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
        n_total = len(full_train)
        n_keep = int(self.pretrain_percentage * n_total)
        # Fixed seed so the subset is reproducible across runs.
        # NOTE(review): CIFAR10FaissDataModule splits the same training set with
        # the same seed (42), so random_split draws the same permutation there.
        # Its FAISS gallery is therefore a subset of this pretraining subset. With
        # the default percentages (0.2 here, 0.1 there), the remaining pretraining
        # images land in its train/val splits, so val is not disjoint from them.
        seeded = torch.Generator().manual_seed(42)
        self.pretrain_data, _ = random_split(full_train, [n_keep, n_total - n_keep], generator=seeded)

    def train_dataloader(self):
        return DataLoader(self.pretrain_data, batch_size=self.batch_size, shuffle=True)

In [3]:
# Pretrain encoder: supervised cross-entropy training of Encoder + linear head
# on the PretrainDataModule subset (20% of the CIFAR-10 training split by default).
# Only the encoder weights are saved; they are loaded later via
# RelativeLearningModel.load_pretrained_encoder("../models/pretrained_encoder.pth").
# NOTE(review): PretrainDataModule and CIFAR10FaissDataModule both split with seed 42,
# so these pretraining images overlap the FAISS gallery and part of the main train/val splits.
pretrain_dm = PretrainDataModule()
pretrain_model = SimpleEncoderModel(embedding_dim=128, num_classes=10, lr=1e-3)
# accelerator='gpu': this cell expects a GPU to be available.
trainer_pretrain = pl.Trainer(max_epochs=15, accelerator='gpu', devices=1, default_root_dir='../models/contrastive_learning')
trainer_pretrain.fit(pretrain_model, pretrain_dm)
# Save only the encoder's state_dict; the classifier head is not reused.
torch.save(pretrain_model.encoder.state_dict(), "../models/pretrained_encoder.pth")


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type    | Params | Mode 
-----------------------------------------------
0 | encoder    | Encoder | 11.2 M | train
1 | classifier | Linear  | 1.3 K  | train
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.974    Total estimated model params size (MB)
70        Modules in train mode
0         Modules in eval mode


Epoch 14: 100%|██████████| 313/313 [00:18<00:00, 16.96it/s, v_num=14, pretrain_loss=0.385, pretrain_acc=0.938] 

`Trainer.fit` stopped: `max_epochs=15` reached.


Epoch 14: 100%|██████████| 313/313 [00:19<00:00, 16.45it/s, v_num=14, pretrain_loss=0.385, pretrain_acc=0.938]


In [11]:
class RelativeLearningModel(pl.LightningModule):
    """Classify images from their embedding plus nearest-neighbour context.

    Each input is embedded by ``Encoder``. The embedding queries a FAISS
    ``IndexFlatL2`` built over a labelled gallery
    (``trainer.datamodule.faiss_data``). The decision head receives the
    concatenation of:

    * the query embedding                                  (``embedding_dim``)
    * the mean embedding of its ``k`` nearest neighbours    (``embedding_dim``)
    * how many of those neighbours carry each class label   (``num_classes``)
    * the mean squared-L2 distance to those neighbours      (1)

    Args:
        embedding_dim: encoder output size; must match the pretrained encoder.
        k: number of neighbours retrieved per query.
        num_classes: number of target classes.
        batch_size: batch size used when encoding the gallery.
        freeze_encoder: if True, ``load_pretrained_encoder`` disables encoder
            gradients, and the encoder is kept in eval mode.
        lr: Adam learning rate.

    Index freshness: the gallery is re-encoded at the start of a train,
    validation or test epoch whenever the index is empty or marked stale.
    The index is marked stale by:

    * every training step while the encoder has trainable parameters,
    * ``load_pretrained_encoder``,
    * ``on_load_checkpoint``.

    Test epochs always re-encode. Encoding the gallery only once would let a
    fine-tuned encoder drift away from its own index.
    """

    def __init__(self,
                 embedding_dim=128,
                 k=5,
                 num_classes=10,
                 batch_size=32,
                 freeze_encoder=True,
                 lr=1e-3):

        super().__init__()
        self.save_hyperparameters()

        self.encoder = Encoder(embedding_dim=embedding_dim, pretrained=False)
        self.k = k
        self.num_classes = num_classes
        self.batch_size = batch_size

        # Rows of the index line up with gallery_embeddings / gallery_labels,
        # which are kept on the CPU.
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.gallery_embeddings = torch.empty(0, embedding_dim)
        self.gallery_labels = torch.empty(0, dtype=torch.long)
        self._index_stale = True

        # Input: query embedding + mean neighbour embedding + label counts + mean distance.
        # Derived from embedding_dim (not a literal 128) so other embedding sizes work.
        decision_input_dim = 2 * embedding_dim + num_classes + 1
        self.decision_head = nn.Sequential(
            nn.Linear(decision_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    # ------------------------------------------------------------------ helpers

    def _freeze_requested(self):
        """Return whether the ``freeze_encoder`` hyperparameter is set."""
        return bool(self.hparams.get("freeze_encoder", False))

    def _encoder_is_trainable(self):
        """Return True if any encoder parameter still receives gradients."""
        return any(p.requires_grad for p in self.encoder.parameters())

    def _faiss_loader(self):
        """Return a DataLoader over the attached datamodule's ``faiss_data``."""
        datamodule = getattr(self.trainer, "datamodule", None)
        gallery = getattr(datamodule, "faiss_data", None)
        if gallery is None:
            raise RuntimeError(
                "RelativeLearningModel needs a datamodule exposing `faiss_data` "
                "to build its FAISS index"
            )
        return DataLoader(gallery, batch_size=self.batch_size)

    def _sync_faiss_index(self, force=False):
        """Re-encode the gallery if forced, if the index is empty, or if it is stale."""
        if force or self._index_stale or self.index.ntotal == 0:
            self.build_faiss_index(self._faiss_loader())

    # ------------------------------------------------------------ public API

    def train(self, mode=True):
        """Switch train/eval mode, but keep a frozen encoder in eval mode.

        Something may call ``self.train()``, for example a trainer loop after
        validation. Without this guard, a frozen encoder's BatchNorm layers
        would go back to updating their running statistics, changing its
        embeddings relative to the stored gallery.
        """
        super().train(mode)
        if self._freeze_requested():
            self.encoder.eval()
        return self

    def load_pretrained_encoder(self, path):
        """Load an encoder ``state_dict`` from ``path``.

        With ``freeze_encoder=True``, the encoder's parameters stop requiring
        gradients and the encoder is put in eval mode. The FAISS index is
        marked stale so the gallery is re-encoded with the new weights.
        """
        state_dict = torch.load(path, map_location=self.device, weights_only=True)
        self.encoder.load_state_dict(state_dict)
        if self._freeze_requested():
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.eval()
        self._index_stale = True

    def build_faiss_index(self, dataloader):
        """Encode every ``(x, y)`` batch of ``dataloader`` into a fresh index.

        Embeddings (float32, CPU) are stored in ``gallery_embeddings`` and
        labels in ``gallery_labels``, row-aligned with the FAISS index.
        The encoder runs in eval mode without gradients. Its previous mode is
        restored afterwards, except that a frozen encoder stays in eval mode.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        was_training = self.encoder.training
        self.encoder.eval()
        embedding_chunks, label_chunks = [], []
        try:
            with torch.no_grad():
                for x_faiss, y_faiss in dataloader:
                    emb, _ = self.encoder(x_faiss.to(self.device))
                    # Keep the gallery in float32 (FAISS's working dtype), even
                    # if the encoder ran in reduced precision.
                    embedding_chunks.append(emb.float().cpu())
                    label_chunks.append(y_faiss)
        finally:
            self.encoder.train(was_training and not self._freeze_requested())

        if not embedding_chunks:
            raise ValueError("FAISS gallery dataloader yielded no batches")

        self.gallery_embeddings = torch.cat(embedding_chunks, dim=0).contiguous()
        self.gallery_labels = torch.cat(label_chunks, dim=0)
        self.index.reset()
        self.index.add(self.gallery_embeddings.numpy())
        self._index_stale = False

    # ------------------------------------------------------------ Lightning hooks

    def on_train_start(self):
        # Make sure the index exists before the first training step.
        self._sync_faiss_index()

    def on_train_epoch_start(self):
        # Re-encode the gallery if the previous epoch updated the encoder.
        self._sync_faiss_index()

    def on_validation_epoch_start(self):
        # Validation must never query a missing or outdated index.
        self._sync_faiss_index()

    def on_test_epoch_start(self):
        # Weights may have been swapped (e.g. restored from a checkpoint); always re-encode.
        self._sync_faiss_index(force=True)

    def on_load_checkpoint(self, checkpoint):
        # Loaded weights invalidate the stored gallery embeddings.
        self._index_stale = True

    # ------------------------------------------------------------ model logic

    def get_neighbor_features(self, embeddings):
        """Summarise the ``k`` nearest gallery entries of each query embedding.

        Args:
            embeddings: (B, D) query embeddings.

        Returns:
            mean_neighbor_emb: (B, D) mean of the neighbours' gallery embeddings.
            label_counts: (B, num_classes) float32 counts of neighbour labels.
            mean_dist: (B, 1) mean distance to the neighbours. ``IndexFlatL2``
                reports squared L2 distances.

        Raises:
            RuntimeError: if the FAISS index is empty.
        """
        if self.index.ntotal == 0:
            raise RuntimeError("FAISS index is empty; call build_faiss_index first")
        # Never request more neighbours than the gallery holds. FAISS pads missing
        # results with index -1, which would silently select the last gallery row.
        k = min(self.k, self.index.ntotal)
        queries = embeddings.detach().float().cpu().contiguous().numpy()
        distances, indices = self.index.search(queries, k)
        indices = torch.from_numpy(indices)

        neighbor_embs = self.gallery_embeddings[indices]    # (B, k, D)
        neighbor_labels = self.gallery_labels[indices]      # (B, k)

        mean_neighbor_emb = neighbor_embs.mean(dim=1).to(self.device)  # (B, D)
        label_counts = F.one_hot(neighbor_labels, self.num_classes).sum(dim=1)
        label_counts = label_counts.to(device=self.device, dtype=torch.float32)
        mean_dist = torch.from_numpy(distances).to(self.device).mean(dim=1, keepdim=True)  # (B, 1)

        return mean_neighbor_emb, label_counts, mean_dist

    def shared_step(self, batch):
        """Return ``(loss, accuracy)`` for one ``(x, y)`` batch."""
        x, y = batch
        embeddings, _ = self.encoder(x)  # (B, D)
        mean_neighbor_emb, label_counts, mean_dist = self.get_neighbor_features(embeddings)

        combined_features = torch.cat([embeddings, mean_neighbor_emb, label_counts, mean_dist], dim=-1)
        logits = self.decision_head(combined_features)

        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        if self._encoder_is_trainable():
            # This step updates the encoder, so the stored gallery embeddings go stale.
            self._index_stale = True
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class CIFAR10FaissDataModule(pl.LightningDataModule):
    """CIFAR-10 splits for retrieval-augmented training.

    The 50k training images are split twice, each time with a fresh
    ``torch.Generator`` seeded with 42:

    1. ``faiss_data``: ``faiss_percentage`` of the images (the FAISS gallery),
       with the rest kept aside.
    2. The rest is split into ``train_data`` and ``val_data``
       (``val_percentage`` of the rest).

    ``test_data`` is the official CIFAR-10 test split.

    NOTE(review): PretrainDataModule also uses seed 42 on the same training
    set, so random_split draws the same first permutation. With the default
    percentages, the gallery lies inside the pretraining subset, and the rest
    of that subset falls into train/val.
    """

    def __init__(self, batch_size=32, faiss_percentage=0.1, val_percentage=0.1, data_dir='../data/raw'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.faiss_percentage = faiss_percentage
        self.val_percentage = val_percentage

    def prepare_data(self):
        # Download both splits; the dataset objects are built in setup().
        for is_train in (True, False):
            datasets.CIFAR10(root=self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        to_tensor = transforms.Compose([transforms.ToTensor()])
        train_pool = datasets.CIFAR10(root=self.data_dir, train=True, transform=to_tensor)

        n_pool = len(train_pool)
        n_faiss = int(self.faiss_percentage * n_pool)
        n_rest = n_pool - n_faiss
        self.faiss_data, rest = random_split(
            train_pool, [n_faiss, n_rest], generator=torch.Generator().manual_seed(42)
        )

        n_val = int(self.val_percentage * n_rest)
        self.train_data, self.val_data = random_split(
            rest, [n_rest - n_val, n_val], generator=torch.Generator().manual_seed(42)
        )
        self.test_data = datasets.CIFAR10(root=self.data_dir, train=False, transform=to_tensor)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.batch_size)

In [12]:
from pytorch_lightning.callbacks import EarlyStopping

# Stop training once val_loss has not improved for `patience` (5) validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='min'
)

# 10% of the CIFAR-10 training split becomes the FAISS gallery; 10% of the remainder is validation.
main_dm = CIFAR10FaissDataModule(faiss_percentage=0.1, val_percentage=0.1)
# freeze_encoder=False: the pretrained encoder is fine-tuned jointly with the
# decision head (default lr=1e-3), so its embeddings change during training.
model = RelativeLearningModel(freeze_encoder=False)
# Weights written by the pretraining cell above.
model.load_pretrained_encoder("../models/pretrained_encoder.pth")

# accelerator='gpu': this cell expects a GPU to be available.
trainer = pl.Trainer(max_epochs=15, callbacks=[early_stopping], accelerator='gpu', devices=1, default_root_dir='../models/contrastive_learning')
trainer.fit(model, main_dm)
# Evaluates on the official CIFAR-10 test split (main_dm.test_dataloader).
# NOTE(review): no ckpt_path is given, so this presumably uses the in-memory
# weights from the end of fit rather than the best checkpoint — confirm.
trainer.test(model, datamodule=main_dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params | Mode 
-----------------------------------------------------
0 | encoder       | Encoder    | 11.2 M | train
1 | decision_head | Sequential | 71.2 K | train
-----------------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.253    Total estimated model params size (MB)
73        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 1266/1266 [01:24<00:00, 15.01it/s, v_num=18, train_acc=0.300] 
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/141 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/141 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/141 [00:00<00:02, 50.62it/s][A
Validation DataLoader 0:   1%|▏         | 2/141 [00:00<00:02, 47.14it/s][A
Validation DataLoader 0:   2%|▏         | 3/141 [00:00<00:03, 45.71it/s][A
Validation DataLoader 0:   3%|▎         | 4/141 [00:00<00:03, 44.51it/s][A
Validation DataLoader 0:   4%|▎         | 5/141 [00:00<00:03, 44.20it/s][A
Validation DataLoader 0:   4%|▍         | 6/141 [00:00<00:03, 43.99it/s][A
Validation DataLoader 0:   5%|▍         | 7/141 [00:00<00:03, 43.79it/s][A
Validation DataLoader 0:   6%|▌         | 8/141 [00:00<00:03, 43.65it/s][A
Validation DataLoader 0:   6%|▋         | 9/141 [00:00<00:03, 43.50it/s][A
Validation DataLoader 0:   7%|▋         | 10/141 [00:00

Metric val_loss improved. New best score: 2.061


Epoch 1: 100%|██████████| 1266/1266 [01:23<00:00, 15.15it/s, v_num=18, train_acc=0.200, val_loss=2.060, val_acc=0.232] 
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/141 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/141 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/141 [00:00<00:03, 45.78it/s][A
Validation DataLoader 0:   1%|▏         | 2/141 [00:00<00:03, 42.82it/s][A
Validation DataLoader 0:   2%|▏         | 3/141 [00:00<00:03, 42.91it/s][A
Validation DataLoader 0:   3%|▎         | 4/141 [00:00<00:03, 43.00it/s][A
Validation DataLoader 0:   4%|▎         | 5/141 [00:00<00:03, 43.61it/s][A
Validation DataLoader 0:   4%|▍         | 6/141 [00:00<00:03, 43.87it/s][A
Validation DataLoader 0:   5%|▍         | 7/141 [00:00<00:03, 44.16it/s][A
Validation DataLoader 0:   6%|▌         | 8/141 [00:00<00:02, 44.41it/s][A
Validation DataLoader 0:   6%|▋         | 9/141 [00:00<00:02, 44.63it/s][A
Validation DataLoader 0:


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined