In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import torchvision
import cv2
from ZeroShotDataset import ZeroShotDataset
from params import *

from transformers import CLIPProcessor, CLIPModel

from ConditionedSegFormerPE import ConditionedSegFormer
from LossFunc import *
import pytorch_lightning as pl
import torchmetrics as tm

# pytorch random train test split
from torch.utils.data import random_split

KeyboardInterrupt: 

In [None]:
class CLIPConditionedSegFormer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained('openai/clip-vit-base-patch16')
        self.processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch16')

        self.segformer = ConditionedSegFormer(
            ModelParams.INCHANNELS,
            ModelParams.WIDTHS,
            ModelParams.DEPTHS,
            512,
            768,
            ModelParams.PATCH_SIZES,
            ModelParams.OVERLAP_SIZES,
            ModelParams.NUM_HEADS,
            ModelParams.EXPANSION_FACTORS,
            ModelParams.DECODER_CHANNELS,
            ModelParams.SCALE_FACTORS
        )

        self.plot_every = TrainParams.PLOT_EVERY

        self.neloss = NELoss(LossParams.ALPHA, LossParams.BETA)
        # self.neloss = FocalLoss()

        self.acc = tm.Accuracy(task="binary", threshold=LossParams.THRESHOLD)
        self.dice = DiceLoss()
        self.iou = IoULoss(LossParams.THRESHOLD)
        self.f1score = tm.F1Score(task="binary", threshold=LossParams.THRESHOLD)

        # freeze CLIP
        for param in self.clip.parameters():
            param.requires_grad = False

    def forward(self, x, condition):
        condition = self.clip.text_model(condition).last_hidden_state
        pe = self.clip.vision_model(x).last_hidden_state

        out = self.segformer(x, pe, condition)
        return out
    
    def training_step(self, batch, batch_idx):
        x, condition, y = batch
        y_hat = self(x, condition)
        loss = self.neloss(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        
        if self.global_step % self.plot_every == 0: 
            y = y.repeat(1, 3, 1, 1)
            y_hat = torch.sigmoid(y_hat)
            y_hat = y_hat.repeat(1, 3, 1, 1)

            # x_grid = torchvision.utils.make_grid(train_dataset.image_inverse_transform(x))
            # self.logger.experiment.add_image('train_sample_image', x_grid, self.global_step)

            grid = torchvision.utils.make_grid(torch.cat([y, y_hat], dim=0))
            self.logger.experiment.add_image('train_sample_mask', grid, self.global_step)

        return loss
    
    def validation_step(self, batch, batch_idx):
        x, condition, y = batch
        y_hat = self(x, condition)
        loss = self.neloss(y_hat, y)

        t_y_hat = torch.where(y_hat > 0.5, 1, 0).long()
        t_y = torch.where(y > 0.5, 1, 0).long()

        acc = self.acc(t_y_hat, t_y)
        dice = self.dice(y_hat, y)
        iou = self.iou(y_hat, y)
        f1 = self.f1score(t_y_hat, t_y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.log("val_dice", dice, prog_bar=True)
        self.log("val_iou", iou, prog_bar=True)
        self.log("val_f1", f1, prog_bar=True)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        x, condition, y = batch
        y_hat = self(x, condition)

        t_y_hat = torch.where(y_hat > 0.5, 1, 0).long()
        t_y = torch.where(y > 0.5, 1, 0).long()

        acc = self.acc(t_y_hat, t_y)
        dice = self.dice(y_hat, y)
        iou = self.iou(y_hat, y)
        f1 = self.f1score(t_y_hat, t_y)

        self.log("test_acc", acc, prog_bar=True)
        self.log("test_dice", dice, prog_bar=True)
        self.log("test_iou", iou, prog_bar=True)
        self.log("test_f1", f1, prog_bar=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-7)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.99)
        return {
            "optimizer": optimizer, 
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }         
        }

In [None]:
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch16')

In [None]:
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
complete_dataset = ZeroShotDataset(
    df = train_df, 
    # image_folder = 'C:/Datasets/COCO/P/ProcessedDataset/images/train/',
    image_folder = 'ProcessedDataset/images/train/',
    # mask_folder = "C:/Datasets/COCO/P/ProcessedDataset/masks/train/",
    mask_folder = 'ProcessedDataset/masks/train/',
    mask_size = 56,
    templates = TrainParams.TEMPLATES, 
    unseen_classes = TrainParams.UNSEEN_CLASSES, 
    image_processor = clip_processor, 
    tokenizer = clip_processor.tokenizer, 
    filter_unseen = False,
    filter_seen = True
)

train_size = int(TrainParams.TRAIN_VAL_SPLIT * len(complete_dataset))
val_size = len(complete_dataset) - train_size
train_dataset, val_dataset = random_split(complete_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED))

In [None]:
print(f"Number of training images: {len(train_dataset)}")   
print(f"Number of val images: {len(val_dataset)}")   

Number of training images: 298426
Number of val images: 74607


In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=TrainParams.BATCH_SIZE, shuffle=True, collate_fn=complete_dataset.collate_fn, num_workers=TrainParams.NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=TrainParams.BATCH_SIZE, shuffle=False, collate_fn=complete_dataset.collate_fn, num_workers=TrainParams.NUM_WORKERS)

In [None]:
test_model = CLIPConditionedSegFormer()

In [None]:
# torch.set_float32_matmul_precision('medium')

In [None]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}-{val_loss:.2f}-{val_r2:.2f}',
    save_top_k=3,
    monitor='val_iou',
    every_n_epochs=1,
    mode='max'
)

trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=30,
    callbacks=[
        checkpoint_callback,
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(test_model, train_loader, val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: c:\Users\david\OneDrive\Documents\GitHub\TextualSegFormer\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                 | Params
---------------------------------------------------
0 | clip      | CLIPModel            | 149 M 
1 | segformer | ConditionedSegFormer | 18.5 M
2 | neloss    | NELoss               | 0     
3 | acc       | Accuracy             | 0     
4 | dice      | DiceLoss             | 0     
5 | iou       | IoULoss              | 0     
6 | f1score   | F1Score              | 0     
---------------------------------------------------
18.5 M    Trainable param

Training: 0it [00:00, ?it/s]