In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import torchvision
import cv2
from ZeroShotDataset import ZeroShotDataset
from params import *

from transformers import CLIPProcessor, CLIPModel

from ConditionedSegFormerPE import ConditionedSegFormer
from LossFunc import *
import pytorch_lightning as pl
import torchmetrics as tm

In [2]:
class CLIPConditionedSegFormer(pl.LightningModule):
    """Text-conditioned segmentation: a frozen CLIP model feeding a ConditionedSegFormer.

    The CLIP text encoder embeds ``condition`` and the CLIP vision encoder embeds the
    image ``x``. Both ``last_hidden_state`` tensors are passed, together with ``x``
    itself, to ``ConditionedSegFormer``, whose output is the predicted mask. Every
    CLIP parameter is frozen, so only the SegFormer part is trained.

    Args:
        clip_name: Hugging Face checkpoint used for both ``CLIPModel`` and
            ``CLIPProcessor``. Defaults to the checkpoint this notebook was built on.
    """

    # Ground-truth masks are binarized at this value before accuracy/F1.
    # NOTE(review): assumes masks are stored in [0, 1] — TODO confirm in ZeroShotDataset.
    TARGET_THRESHOLD = 0.5

    def __init__(self, clip_name='openai/clip-vit-base-patch16'):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name)
        # Not used by forward(); kept as an attribute for callers that preprocess inputs.
        self.processor = CLIPProcessor.from_pretrained(clip_name)

        # Widths of the CLIP text / vision hidden states that forward() hands to the
        # SegFormer (512 and 768 for clip-vit-base-patch16). Reading them from the
        # config keeps the model consistent if another CLIP checkpoint is used.
        text_dim = self.clip.config.text_config.hidden_size
        vision_dim = self.clip.config.vision_config.hidden_size

        self.segformer = ConditionedSegFormer(
            ModelParams.INCHANNELS,
            ModelParams.WIDTHS,
            ModelParams.DEPTHS,
            text_dim,
            vision_dim,
            ModelParams.PATCH_SIZES,
            ModelParams.OVERLAP_SIZES,
            ModelParams.NUM_HEADS,
            ModelParams.EXPANSION_FACTORS,
            ModelParams.DECODER_CHANNELS,
            ModelParams.SCALE_FACTORS
        )

        self.plot_every = TrainParams.PLOT_EVERY

        self.neloss = NELoss(LossParams.ALPHA, LossParams.BETA)

        # Test-time metrics. Accuracy/F1 are torchmetrics objects and are logged as
        # objects in test_step so they are aggregated over the whole test set.
        self.acc = tm.Accuracy(task="binary", threshold=LossParams.THRESHOLD)
        self.dice = DiceLoss()
        self.iou = IoULoss(LossParams.THRESHOLD)
        self.f1score = tm.F1Score(task="binary", threshold=LossParams.THRESHOLD)

        # Freeze CLIP: only the SegFormer parameters remain trainable.
        self.clip.requires_grad_(False)

    def forward(self, x, condition):
        """Predict a mask for image ``x`` conditioned on ``condition``.

        Args:
            x: image tensor fed to both the CLIP vision encoder and the SegFormer
                (presumably ``pixel_values`` from ``CLIPProcessor`` — TODO confirm).
            condition: input to the CLIP text encoder (presumably token ids from
                the CLIP tokenizer — TODO confirm).

        Returns:
            The output of ``ConditionedSegFormer``; in this notebook it has shape
            ``(batch, 1, 56, 56)``.
        """
        text_features = self.clip.text_model(condition).last_hidden_state
        image_features = self.clip.vision_model(x).last_hidden_state
        return self.segformer(x, image_features, text_features)

    def _shared_loss_step(self, batch):
        """Unpack an ``(image, condition, mask)`` batch and return the NELoss of the prediction."""
        x, condition, y = batch
        y_hat = self(x, condition)
        return self.neloss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one batch; the loss is returned for backprop."""
        loss = self._shared_loss_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` for one batch."""
        loss = self._shared_loss_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log accuracy, F1, Dice and IoU of the predictions for one test batch."""
        x, condition, y = batch
        y_hat = self(x, condition)

        # Binarize predictions with the same threshold the metrics and IoULoss are
        # configured with, so accuracy/F1 agree with IoU.
        # NOTE(review): assumes y_hat is a probability map, not logits — TODO confirm.
        pred_mask = (y_hat > LossParams.THRESHOLD).long()
        true_mask = (y > self.TARGET_THRESHOLD).long()

        # Update the stateful metrics and log the Metric objects themselves, so
        # Lightning computes accuracy/F1 over the full test set at epoch end
        # instead of averaging per-batch values (which is wrong for F1).
        self.acc(pred_mask, true_mask)
        self.f1score(pred_mask, true_mask)
        self.log("test_acc", self.acc, prog_bar=True)
        self.log("test_f1", self.f1score, prog_bar=True)

        # DiceLoss / IoULoss receive the raw predictions; their per-batch values
        # are logged as tensors and averaged by Lightning.
        self.log("test_dice", self.dice(y_hat, y), prog_bar=True)
        self.log("test_iou", self.iou(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        """AdamW over the trainable (non-CLIP) parameters, with a StepLR decay.

        The scheduler is stepped every optimizer step, so the learning rate is
        multiplied by 0.99 every 2000 steps.
        """
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, betas=(0.9, 0.98), eps=1e-7)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.99)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # step the scheduler per optimizer step, not per epoch
                "frequency": 1,
            },
        }

In [3]:
# Training-split table handed to ZeroShotDataset as `df`.
train_df = pd.read_csv(filepath_or_buffer="train.csv")

In [4]:
clip_checkpoint = 'openai/clip-vit-base-patch16'
# Supplies both the image processor and the tokenizer used by the dataset.
clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)

In [5]:
# Root of the preprocessed COCO data; edit this one constant to relocate the dataset.
DATA_ROOT = "C:/Datasets/COCO/P/ProcessedDataset"
# Side length of the ground-truth masks (the batch below has shape (B, 1, 56, 56)).
MASK_SIZE = 56

train_dataset = ZeroShotDataset(
    df=train_df,
    # Trailing slashes kept: the dataset may join file names by string concatenation.
    image_folder=f"{DATA_ROOT}/images/train/",
    mask_folder=f"{DATA_ROOT}/masks/train/",
    mask_size=MASK_SIZE,
    templates=TrainParams.TEMPLATES,
    # NOTE(review): 'unseen_clases' looks misspelled; it must match ZeroShotDataset's
    # parameter name exactly — verify it is not silently swallowed by **kwargs.
    unseen_clases=[],
    image_processor=clip_processor,
    tokenizer=clip_processor.tokenizer,
    filter_unseen=False,
    filter_seen=True,
)

In [6]:
# Seed the sampler's shuffling so the batch inspected below (and its loss) is reproducible.
LOADER_SEED = 0
loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TrainParams.BATCH_SIZE,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [7]:
# Take a single batch for the smoke tests that follow.
batch_iterator = iter(loader)
batch = next(batch_iterator)
del batch_iterator

In [8]:
# Same (image, condition, mask) layout that training_step unpacks.
(image, text, mask) = batch

In [9]:
# Untrained model instance used only for the shape/loss sanity checks below.
test_model = CLIPConditionedSegFormer()

In [10]:
# Shape check only: no autograd graph is needed for this forward pass.
with torch.no_grad():
    out = test_model(image, text)

In [11]:
out.size()

torch.Size([16, 1, 56, 56])

In [12]:
# training_step feeds the prediction and this mask to NELoss together; their shapes must agree.
assert out.shape == mask.shape, f"prediction {tuple(out.shape)} vs mask {tuple(mask.shape)}"
mask.shape

torch.Size([16, 1, 56, 56])

In [13]:
# Called directly (no Trainer attached) to check that the loss computes and carries a grad_fn.
test_model.training_step(batch, batch_idx=0)

tensor(0.5437, grad_fn=<AddBackward0>)