In [4]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from diffusers import UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDPMScheduler
@dataclass
class TrainingConfig:
    """Notebook-wide settings.

    Every field carries a type annotation so that ``@dataclass`` registers it
    as a real field: un-annotated assignments are plain class attributes and
    are ignored by the generated ``__init__``/``__repr__``/``__eq__``. With the
    annotations, values can be overridden per instance, e.g.
    ``TrainingConfig(batch_size=128)``.
    """

    image_size: int = 32  # the generated image resolution
    saved_model: str = ""  # saved model path
    class_num: int = 10  # number of target classes
    batch_size: int = 512  # batch size used by the data loaders
    seed: int = 24  # random seed
config = TrainingConfig()

In [5]:
from torchvision import transforms

# Image pipeline: resize and centre-crop to the configured resolution, flip at
# random, convert to a tensor, then normalise each of the three channels with
# mean 0.5 and std 0.5.
# NOTE(review): the random flip makes the pipeline non-deterministic; this same
# `preprocess` object is also passed to the test split below — confirm that is
# intended for evaluation.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=config.image_size),
        transforms.CenterCrop(size=config.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

## Extracting the output of the mid block of the Unet model (backbone)

In [6]:

class NewModel(nn.Module):
    """Feature extractor that runs a UNet only up to (and including) its mid block.

    The sub-modules are supplied by the caller rather than created here, so the
    weights are shared with whatever model they were taken from.

    Args:
        conv_in: input convolution applied to the (optionally centred) sample.
        time_proj: callable mapping a 1-D timestep tensor to a timestep projection.
        down_blocks: iterable of down-sampling blocks, applied in order.
        mid_block: block applied after the last down block; its output is returned.
        dtype: dtype the timestep projection is cast to before ``time_embedding``.
        time_embedding: module turning the projection into the time embedding.
        config: object exposing a boolean ``center_input_sample`` attribute.
    """

    def __init__(self, conv_in, time_proj, down_blocks, mid_block, dtype, time_embedding, config):
        super().__init__()
        self.conv_in = conv_in
        self.time_proj = time_proj
        self.down_blocks = down_blocks
        self.mid_block = mid_block
        self.dtype = dtype
        self.time_embedding = time_embedding
        self.config = config

    def forward(self, sample, timesteps=0):
        """Return the mid-block activations for ``sample``.

        Args:
            sample: batch of inputs; only ``sample.shape[0]`` (the batch size)
                and ``sample.device`` are inspected here.
            timesteps: timestep fed to the time embedding. May be a Python
                number, a 0-d tensor, or a 1-D tensor that broadcasts against
                the batch size. Defaults to ``0``, the value that used to be
                hard-coded, so existing ``forward(sample)`` calls are unchanged.

        Returns:
            The output of ``mid_block``.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time: bring the timestep to a 1-D tensor on the sample's device,
        # then broadcast it to one entry per batch element.
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        else:
            if timesteps.dim() == 0:
                timesteps = timesteps[None]
            timesteps = timesteps.to(sample.device)
        timesteps = timesteps * torch.ones(
            sample.shape[0], dtype=timesteps.dtype, device=timesteps.device
        )

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down. The residuals each block returns are only needed by the
        # up-sampling path, which is not run here, so they are discarded.
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                # Blocks with a skip convolution also consume/update the skip sample.
                sample, _res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, _res_samples = downsample_block(hidden_states=sample, temb=emb)

        # 4. mid
        sample = self.mid_block(sample, emb)

        # Return the output from the mid-block
        return sample

In [7]:
# Load the pretrained UNet from the configured path (`config.saved_model` must
# be filled in; it is an empty string by default).
Unet = UNet2DModel.from_pretrained(config.saved_model)

# Re-use the loaded UNet's own sub-modules so the truncated model shares its weights.
AlteredUnetModel = NewModel(
    conv_in=Unet.conv_in,
    time_proj=Unet.time_proj,
    down_blocks=Unet.down_blocks,
    mid_block=Unet.mid_block,
    dtype=Unet.dtype,
    time_embedding=Unet.time_embedding,
    config=Unet.config,
)

## Loading the CIFAR-10 dataset

In [8]:
from torchvision import datasets
# Training split; downloaded into its own root directory if not already present.
datasetCIFAR10 = datasets.CIFAR10(
    root='/artifacts/datasetcifar10train',
    train=True,
    download=True,
    transform=preprocess,
)
# Test split; kept under a separate root directory from the training split.
datasetCIFAR10test = datasets.CIFAR10(
    root='/artifacts/datasetcifar10test',
    train=False,
    download=True,
    transform=preprocess,
)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Loader over the training split; batches are drawn in shuffled order.
dataloader_train = torch.utils.data.DataLoader(
    dataset=datasetCIFAR10, batch_size=config.batch_size, shuffle=True
)
# Loader over the test split; batches keep the dataset order.
dataloader_val = torch.utils.data.DataLoader(
    dataset=datasetCIFAR10test, batch_size=config.batch_size, shuffle=False
)

# k-NN classification

The classifier below was not written by me; its source can be found here: https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking/knn_classifier.py



In [10]:
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Module

from lightly.utils.benchmarking import knn_predict
from lightly.utils.benchmarking.topk import mean_topk_accuracy


class KNNClassifier(LightningModule):
    """k-NN evaluation of a feature extractor, expressed as a LightningModule.

    The "training" phase performs no optimisation: it only runs ``model`` on the
    training batches and stores the resulting features and labels. At the start
    of validation the stored features become the feature bank, and every
    validation batch is classified against it with ``knn_predict``; the top-k
    accuracies are logged as ``val_top{k}``.

    Args:
        model: module whose ``forward(images)`` output is flattened to one
            feature vector per sample.
        num_classes: number of classes passed to ``knn_predict``.
        knn_k: ``knn_k`` argument passed to ``knn_predict``.
        knn_t: ``knn_t`` argument passed to ``knn_predict``.
        topk: values of k passed to ``mean_topk_accuracy``.
        feature_dtype: dtype the features are cast to before being stored/used.
        normalize: whether features are normalised with ``F.normalize`` along dim 1.
    """

    def __init__(
        self,
        model: Module,
        num_classes: int,
        knn_k: int = 200,
        knn_t: float = 0.1,
        topk: Tuple[int, ...] = (1, 5),
        feature_dtype: torch.dtype = torch.float32,
        normalize: bool = True,
    ):
        super().__init__()
        # The dtype is recorded as a string; the other values are recorded as given.
        hyperparameters = dict(
            num_classes=num_classes,
            knn_k=knn_k,
            knn_t=knn_t,
            topk=topk,
            feature_dtype=str(feature_dtype),
        )
        self.save_hyperparameters(hyperparameters)

        self.model = model
        self.num_classes = num_classes
        self.knn_k = knn_k
        self.knn_t = knn_t
        self.topk = topk
        self.feature_dtype = feature_dtype
        self.normalize = normalize

        # Per-batch accumulators filled during the training phase ...
        self._train_features = []
        self._train_targets = []
        # ... and the gathered bank built from them when validation starts.
        self._train_features_tensor: Optional[Tensor] = None
        self._train_targets_tensor: Optional[Tensor] = None

    def _embed(self, images: Tensor) -> Tensor:
        """Run the model on ``images`` and return flattened features in ``feature_dtype``."""
        embeddings = self.model.forward(images).flatten(start_dim=1)
        if self.normalize:
            embeddings = F.normalize(embeddings, dim=1)
        return embeddings.to(self.feature_dtype)

    @torch.no_grad()
    def training_step(self, batch, batch_idx) -> None:
        """Store the features and labels of one training batch on the CPU."""
        images, targets = batch[0], batch[1]
        self._train_features.append(self._embed(images).cpu())
        self._train_targets.append(targets.cpu())

    def validation_step(self, batch, batch_idx) -> None:
        """Classify one validation batch against the feature bank and log top-k accuracy."""
        bank = self._train_features_tensor
        bank_labels = self._train_targets_tensor
        # Nothing to compare against until a bank has been built.
        if bank is None or bank_labels is None:
            return

        images, targets = batch[0], batch[1]
        predictions = knn_predict(
            feature=self._embed(images),
            feature_bank=bank,
            feature_labels=bank_labels,
            num_classes=self.num_classes,
            knn_k=self.knn_k,
            knn_t=self.knn_t,
        )
        accuracies = mean_topk_accuracy(
            predicted_classes=predictions, targets=targets, k=self.topk
        )
        metrics = {f"val_top{k}": acc for k, acc in accuracies.items()}
        self.log_dict(metrics, prog_bar=True, sync_dist=True, batch_size=len(targets))

    def on_validation_epoch_start(self) -> None:
        """Turn the accumulated training features/labels into the feature bank."""
        if not (self._train_features and self._train_targets):
            return

        # After the gather, features are (world_size, batch_size, dim) and targets
        # (world_size, batch_size); without distributed training they are
        # (batch_size, dim) and (batch_size,).
        gathered_features = self.all_gather(torch.cat(self._train_features, dim=0))
        self._train_features = []
        gathered_targets = self.all_gather(torch.cat(self._train_targets, dim=0))
        self._train_targets = []

        # Bank layout: (dim, world_size * batch_size).
        bank = gathered_features.flatten(end_dim=-2).t().contiguous()
        self._train_features_tensor = bank.to(self.device)
        # Label layout: (world_size * batch_size,).
        bank_labels = gathered_targets.flatten().t().contiguous()
        self._train_targets_tensor = bank_labels.to(self.device)

    def on_train_epoch_start(self) -> None:
        """Put the model in eval mode and drop everything collected so far."""
        # Eval mode disables norm layer updates.
        self.model.eval()

        self._train_features = []
        self._train_targets = []
        self._train_features_tensor = None
        self._train_targets_tensor = None

    def configure_optimizers(self) -> None:
        """Required by PyTorch Lightning; returning None means no optimisation is performed."""
        return None

In [11]:
# k-NN evaluator around the truncated UNet. The class count is read from the
# shared config (10 by default) instead of being repeated here as a literal,
# so the two cannot drift apart.
knn_classifier = KNNClassifier(
    model=AlteredUnetModel,
    num_classes=config.class_num,
    knn_k=50,
    knn_t=0.1,
    topk=(1, 2, 3),
    feature_dtype=torch.float32,
    normalize=True,
)

In [None]:
import os
import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Credentials: never paste the key into the notebook. Export WANDB_API_KEY
# before starting the kernel; if it is not set, fall back to `wandb.login()`.
# The previous `os.environ["WANDB_API_KEY"] = ""` overwrote any key that was
# already present in the environment with an empty string.
if not os.environ.get("WANDB_API_KEY"):
    wandb.login()

# Apply the configured seed (it was defined in the config but never used) so
# that shuffling and random augmentation are repeatable across runs.
pl.seed_everything(config.seed, workers=True)

wandb_logger = WandbLogger(
    name="", project=""    #Your run name and project name
)
trainer = Trainer(
    max_epochs=2, devices=1, accelerator="cuda", logger=[wandb_logger]
)
try:
    trainer.fit(knn_classifier, dataloader_train, dataloader_val)
finally:
    # Close the wandb run even if fitting raised.
    wandb.finish()