In [1]:
from dataclasses import dataclass
from diffusers import UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDPMScheduler

@dataclass
class TrainingConfig:
    """Hyper-parameters and paths shared by every cell of this notebook.

    Every attribute carries a type annotation so that ``@dataclass`` turns it into
    a real dataclass field. Without annotations they would be plain class
    attributes, and ``TrainingConfig(batch_size=...)`` would raise ``TypeError``.
    """

    image_size: int = 32  # the generated image resolution; also the input size fed to the backbone
    saved_model: str = ""  # path (or hub id) of the pretrained UNet2DModel; must be set before loading
    class_num: int = 10  # number of target classes (CIFAR-10)
    batch_size: int = 512  # batch size used by both the train and validation loaders
    seed: int = 24  # random seed


config = TrainingConfig()

# Dataset preprocessing

In [2]:
from torchvision import transforms

def _cifar_transform(train: bool) -> transforms.Compose:
    """Build the CIFAR-10 preprocessing pipeline.

    Both splits are resized and center-cropped to ``config.image_size``. They are
    then converted to tensors and normalised with mean/std 0.5, which maps pixel
    values from [0, 1] to [-1, 1]. Only the training pipeline adds a random
    horizontal flip (applied after cropping, before tensor conversion).
    """
    steps = [
        transforms.Resize(config.image_size),
        transforms.CenterCrop(config.image_size),
    ]
    if train:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)


preprocessTrain = _cifar_transform(train=True)
preprocessVal = _cifar_transform(train=False)

In [None]:
from torchvision import datasets
# Download (on first run) and wrap both CIFAR-10 splits. Each split gets its own
# cache directory and its matching preprocessing pipeline.
# NOTE(review): the torchvision archive contains both splits, so two separate
# roots mean it is downloaded and stored twice; confirm this is intended.
CIFAR10_TRAIN_ROOT = '/artifacts/datasetcifar10train'
CIFAR10_TEST_ROOT = '/artifacts/datasetcifar10test'

datasetCIFAR10 = datasets.CIFAR10(
    root=CIFAR10_TRAIN_ROOT,
    train=True,
    download=True,
    transform=preprocessTrain,
)
datasetCIFAR10test = datasets.CIFAR10(
    root=CIFAR10_TEST_ROOT,
    train=False,
    download=True,
    transform=preprocessVal,
)

In [4]:
import torch
# Both loaders share the batch size. Only shuffling differs: training batches
# are reshuffled on every pass, while validation order stays fixed.
_loader_kwargs = {"batch_size": config.batch_size}
dataloader_train = torch.utils.data.DataLoader(datasetCIFAR10, shuffle=True, **_loader_kwargs)
dataloader_val = torch.utils.data.DataLoader(datasetCIFAR10test, shuffle=False, **_loader_kwargs)

# Backbone: keep only the encoder (down blocks) and mid block of the U-Net

In [5]:

import torch.nn as nn
class NewModel(nn.Module):
    """Feature extractor built from the encoder half of a ``UNet2DModel``.

    ``forward`` follows the first steps of ``diffusers.UNet2DModel.forward``:
    input centering, timestep embedding, ``conv_in``, the down blocks and the mid
    block. It then returns the mid-block activations instead of decoding them.

    Args:
        conv_in: the U-Net's input convolution.
        time_proj: timestep projection module (``UNet2DModel.time_proj``).
        down_blocks: iterable of the U-Net's down blocks.
        mid_block: the U-Net's mid block, or ``None``. When it is ``None``, the
            output of the last down block is returned.
        dtype: dtype the projected timesteps are cast to before ``time_embedding``.
        time_embedding: timestep embedding module (``UNet2DModel.time_embedding``).
        config: the U-Net's config. Only ``center_input_sample`` is read.
        up_block: the U-Net's up blocks. Stored as ``self.up_blocks`` but not used
            by ``forward``.
    """

    def __init__(self, conv_in, time_proj, down_blocks, mid_block, dtype, time_embedding, config, up_block):
        super().__init__()
        self.conv_in = conv_in
        self.time_proj = time_proj
        self.down_blocks = down_blocks
        self.mid_block = mid_block
        self.dtype = dtype
        self.time_embedding = time_embedding
        self.config = config
        # NOTE(review): registering the unused up blocks makes their weights part
        # of this module (moved to device, counted in summaries).
        self.up_blocks = up_block

    def forward(self, sample, timestep=0):
        """Encode ``sample`` and return the mid-block feature map.

        Args:
            sample: input image batch (assumed (B, C, H, W) and normalised like
                the U-Net's training data; TODO confirm).
            timestep: diffusion timestep for the time embedding. Accepts an int,
                a 0-d tensor, or a tensor with one entry per sample. The default
                of 0 is the value the original notebook hard-coded.

        Returns:
            The mid-block output, or the last down-block output if ``mid_block``
            is ``None``.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time embedding; `timestep` is normalised to a 1-D tensor first
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif timesteps.dim() == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to one timestep per sample in the batch
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
        t_emb = self.time_proj(timesteps).to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        # 2. pre-process; `skip_sample` feeds down blocks that expose a skip path
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down. The per-block residuals would only be consumed by the decoder,
        # which is not used here, so they are discarded.
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, _, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, _ = downsample_block(hidden_states=sample, temb=emb)

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(sample, emb)
        return sample

In [6]:
# DDPM noise schedule with 1000 training timesteps.
# NOTE(review): `noise_scheduler` is not used by any later cell in this notebook;
# presumably kept from the diffusion-training setup. Confirm before removing.
noise_scheduler = DDPMScheduler(num_train_timesteps=1_000)

In [7]:
# Load the pretrained diffusion U-Net and keep only its encoder + mid block as a
# feature extractor for the linear probe.
if not config.saved_model:
    # Fail fast with an actionable message instead of an opaque loader error.
    raise ValueError(
        "config.saved_model is empty; set it to the path of the pretrained "
        "UNet2DModel before running this cell."
    )
Unet = UNet2DModel.from_pretrained(config.saved_model)
AlteredUnetModel = NewModel(
    conv_in=Unet.conv_in,
    time_proj=Unet.time_proj,
    down_blocks=Unet.down_blocks,
    mid_block=Unet.mid_block,
    dtype=Unet.dtype,
    time_embedding=Unet.time_embedding,
    config=Unet.config,
    up_block=Unet.up_blocks,
)

# Linear probe: train a linear classifier on top of the frozen backbone

In [8]:
import torch
import wandb
import pytorch_lightning as pl
from torch.nn import AdaptiveAvgPool2d
# Parameter-free global average pooling: each (H, W) feature map collapses to 1x1.
global_avg_pool = AdaptiveAvgPool2d(output_size=(1, 1))
from lightly.transforms import utils

class Classifier(pl.LightningModule):
    """Linear probe: a trainable linear head on top of a frozen feature backbone.

    The backbone's output is global-average-pooled, flattened and passed to a
    single ``nn.Linear`` layer. Only that linear layer is optimised.

    Args:
        backbone: feature extractor (here the truncated U-Net). Its output is
            assumed to be (B, C, H, W) with C == ``feature_dim``; TODO confirm
            against the pretrained U-Net's mid-block width.
        feature_dim: number of channels produced by ``backbone``. Defaults to
            256, the width this notebook was written for.
        lr: learning rate of the Adam optimiser for the linear head.
    """

    def __init__(self, backbone, feature_dim: int = 256, lr: float = 1e-3):
        super().__init__()
        # use the pretrained backbone and freeze it: no gradients for its weights
        self.backbone = backbone
        for p in self.backbone.parameters():
            p.requires_grad = False
        self.backbone.eval()

        # linear layer for the downstream classification task
        self.fc = nn.Linear(feature_dim, config.class_num)
        self.lr = lr
        self.criterion = nn.CrossEntropyLoss()
        # (num, correct) pairs buffered by validation_step, drained at epoch end
        self.validation_step_outputs = []

        # lets Lightning report per-layer input/output shapes in the model summary
        self.example_input_array = torch.zeros(1, 3, config.image_size, config.image_size)

    def train(self, mode: bool = True):
        """Set the head's train/eval ``mode`` while always keeping the backbone in eval.

        Lightning calls ``model.train()`` before training. Without this override,
        that call would switch the frozen backbone's mode-dependent layers (e.g.
        dropout) back to training behaviour. ``eval()`` routes through here too.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits of shape (batch, class_num) for the image batch ``x``."""
        # The backbone is frozen, so its activations need not be tracked by autograd.
        with torch.no_grad():
            mid_block_output = self.backbone(x)
        pooled_output = global_avg_pool(mid_block_output)
        features = torch.flatten(pooled_output, start_dim=1)
        return self.fc(features)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss of the linear head on one training batch."""
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log("train_loss_fc", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and count correct top-1 predictions for one batch.

        Returns:
            ``(num, correct)``: the batch size and the number of correct predictions
            (a float tensor). The pair is also buffered for the epoch-end accuracy.
        """
        x, y = batch
        logits = self(x)
        self.log("val_loss", self.criterion(logits, y), on_epoch=True)

        # argmax over logits equals argmax over softmax(logits), so skip the softmax
        predicted = logits.argmax(dim=1)
        num = predicted.shape[0]
        correct = (predicted == y).float().sum()
        self.validation_step_outputs.append((num, correct))
        return num, correct

    def on_validation_epoch_end(self):
        """Log top-1 accuracy over all batches buffered since the last epoch end."""
        if not self.validation_step_outputs:
            return
        total_num = sum(num for num, _ in self.validation_step_outputs)
        total_correct = sum(correct for _, correct in self.validation_step_outputs)
        self.log("val_acc", total_correct / total_num, on_epoch=True, prog_bar=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Adam over the linear head only; the backbone stays frozen."""
        return torch.optim.Adam(self.fc.parameters(), lr=self.lr)

# Training

In [None]:
import os
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Authenticate with Weights & Biases without hard-coding a secret in the notebook.
# wandb.login() uses WANDB_API_KEY if it is already set in the environment (or a
# previously stored credential) and otherwise prompts for the key.
wandb.login()

# Seed the python/numpy/torch RNGs so that head initialisation, data shuffling and
# random flips are reproducible across runs.
pl.seed_everything(config.seed)

# Put the backbone in inference mode. Its weights are frozen separately, by
# Classifier.__init__ setting requires_grad=False; eval() only changes the
# behaviour of mode-dependent layers.
AlteredUnetModel.eval()
classifier = Classifier(AlteredUnetModel)

wandb_logger = WandbLogger(
    name="", project=""  # TODO: fill in your run name and project name
)
trainer = Trainer(
    max_epochs=30, devices=1, accelerator="cuda", logger=[wandb_logger]
)
trainer.fit(classifier, dataloader_train, dataloader_val)
wandb.finish()