In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
import pandas as pd

import torchvision
import pytorch_lightning as pl

from ConditionedSegFormer import ConditionedSegFormer
import torchmetrics as tm
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from LossFunc import *
from params import *

In [2]:
class CLIPConditionedSegFormer(pl.LightningModule):
    """Segmentation LightningModule: a ConditionedSegFormer driven by CLIP text features.

    The text tower of the pretrained 'openai/clip-vit-base-patch32' checkpoint is
    loaded and frozen; its per-token hidden states are passed to the SegFormer as
    the conditioning input. Training/validation minimise ``NELoss``; testing
    reports accuracy, Dice, IoU and F1.
    """

    def __init__(self):
        super().__init__()
        # Keep only the text model of CLIP; the vision tower is discarded.
        self.clip = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').text_model

        self.segformer = ConditionedSegFormer(
            ModelParams.INCHANNELS,
            ModelParams.WIDTHS,
            ModelParams.DEPTHS,
            # NOTE(review): presumably the conditioning feature width, i.e. the
            # CLIP text hidden size -- confirm against ConditionedSegFormer.
            512,
            ModelParams.PATCH_SIZES,
            ModelParams.OVERLAP_SIZES,
            ModelParams.REDUCTION_RATIOS,
            ModelParams.NUM_HEADS,
            ModelParams.EXPANSION_FACTORS,
            ModelParams.DECODER_CHANNELS,
            ModelParams.SCALE_FACTORS
        )

        # Log a sample mask grid every `plot_every` global steps (falsy disables it).
        self.plot_every = TrainParams.PLOT_EVERY

        self.neloss = NELoss(LossParams.ALPHA, LossParams.BETA)

        self.acc = tm.Accuracy(task="binary", threshold=LossParams.THRESHOLD)
        self.dice = DiceLoss()
        self.iou = IoULoss(LossParams.THRESHOLD)
        self.f1score = tm.F1Score(task="binary", threshold=LossParams.THRESHOLD)

        # freeze CLIP
        for param in self.clip.parameters():
            param.requires_grad = False

    def forward(self, x, condition):
        """Run the conditioned segmentation model.

        Args:
            x: image batch forwarded unchanged to the SegFormer.
            condition: tensor handed directly to the CLIP text model
                (presumably token ids -- confirm with the dataset).

        Returns:
            Whatever ``ConditionedSegFormer`` returns for ``(x, text_features)``.
        """
        condition = self.clip(condition).last_hidden_state
        return self.segformer(x, condition)

    def _log_mask_grid(self, y, y_hat):
        """Log target and predicted masks side by side as one image grid.

        Does nothing when no logger is attached or when the logger's experiment
        object has no ``add_image`` method, so training never fails because of
        plotting.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_image"):
            return

        with torch.no_grad():
            # Repeat the channel axis three times so the masks render as RGB images
            # (assumes single-channel masks -- TODO confirm).
            target_rgb = y.detach().repeat(1, 3, 1, 1)
            pred_rgb = y_hat.detach().repeat(1, 3, 1, 1)
            grid = torchvision.utils.make_grid(torch.cat([target_rgb, pred_rgb], dim=0))

        # NOTE(review): the input images are not logged; that would presumably need
        # the dataset's inverse image transform, which is not available here.
        experiment.add_image('train_sample_mask', grid, self.global_step)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; periodically log a sample mask grid."""
        x, y, condition = batch
        y_hat = self(x, condition)
        loss = self.neloss(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)

        # A falsy `plot_every` (0 / None) disables plotting instead of raising
        # ZeroDivisionError in the modulo below.
        if self.plot_every and self.global_step % self.plot_every == 0:
            self._log_mask_grid(y, y_hat)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss (monitored as ``val_loss``)."""
        x, y, condition = batch
        y_hat = self(x, condition)
        loss = self.neloss(y_hat, y)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log accuracy, Dice, IoU and F1 for one test batch.

        Accuracy and F1 receive already-binarised integer tensors, so the cut
        applied here is the one that counts. Predictions are therefore binarised
        with ``LossParams.THRESHOLD`` -- the same value given to the IoU loss and
        to the metric constructors -- instead of a separate hard-coded constant.
        """
        x, y, condition = batch
        y_hat = self(x, condition)

        # Predictions: use the configured decision threshold.
        t_y_hat = (y_hat > LossParams.THRESHOLD).long()
        # Targets: split at the midpoint (assumes masks in [0, 1] -- TODO confirm).
        t_y = (y > 0.5).long()

        acc = self.acc(t_y_hat, t_y)
        dice = self.dice(y_hat, y)
        iou = self.iou(y_hat, y)
        f1 = self.f1score(t_y_hat, t_y)

        self.log("test_acc", acc, prog_bar=True)
        self.log("test_dice", dice, prog_bar=True)
        self.log("test_iou", iou, prog_bar=True)
        self.log("test_f1", f1, prog_bar=True)

    def configure_optimizers(self):
        """Return AdamW plus a per-step StepLR schedule.

        The learning rate is multiplied by 0.99 every 2000 scheduler steps; the
        scheduler is stepped once per optimisation step.
        """
        # The frozen CLIP parameters are included in `self.parameters()` too; they
        # have requires_grad=False, so they receive no gradient.
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-7)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.99)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

In [3]:
# Build the Lightning module. Its __init__ loads the pretrained
# 'openai/clip-vit-base-patch32' checkpoint and constructs the SegFormer from ModelParams.
model = CLIPConditionedSegFormer()

In [4]:
# Dummy inputs for a forward-pass smoke test.
# image: a single random 3-channel 256x256 tensor drawn from a standard normal.
image = torch.randn((1, 3, 256, 256))
# condition: one row of 77 random integer ids in [0, 10000), standing in for tokenised text.
# NOTE(review): 77 presumably matches the CLIP text context length -- confirm.
condition = torch.randint(low=0, high=10000, size=(1, 77))


In [5]:
# Shape-check forward pass only: disable autograd so no computation graph is
# built and kept alive through `out` in the kernel.
with torch.no_grad():
    out = model(image, condition)

In [6]:
# Display the output shape; for the 1x3x256x256 dummy input the recorded result is [1, 1, 64, 64].
out.shape

torch.Size([1, 1, 64, 64])

In [7]:
# Checkpointing: keep the three checkpoints with the lowest logged "val_loss".
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    filename="clip_conditioned_segformer-{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    mode="min",
)
# Record the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# GPU trainer running for the configured number of epochs with the callbacks above.
trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=TrainParams.EPOCHS,
    callbacks=[checkpoint_callback, lr_monitor],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
