In [1]:
import warnings
# Silence every Python warning for the rest of the session. The call sits
# ahead of the remaining imports, so warnings raised while importing them are
# hidden as well.
# NOTE(review): a blanket 'ignore' also hides deprecation notices that may
# matter -- consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

import torch
import pandas as pd
import numpy as np
import torchvision
import cv2
import pytorch_lightning as pl
import torchmetrics as tm

from torch.utils.data import Dataset
from ZeroShotDataset import ZeroShotDataset
from params import *
from DatasetModeling import *
from transformers import CLIPProcessor, CLIPModel
from LossFunc import *
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from torch.utils.data import random_split
from CLIPConditionedSegFormerModel import CLIPConditionedSegFormer

In [2]:
# 'medium' allows float32 matrix multiplications to use reduced-precision
# internal arithmetic when a fast kernel is available, trading a little
# accuracy for throughput (see the PyTorch documentation for this setting).
MATMUL_PRECISION = 'medium'
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [3]:
# Hugging Face checkpoint whose processor is loaded below.
CLIP_CHECKPOINT = 'openai/clip-vit-base-patch16'

# Annotation table for the training split.
# NOTE(review): when displayed, the frame contains an 'Unnamed: 0' column,
# which suggests the CSV was written together with its index. Passing
# index_col=0 would drop it -- confirm first that nothing downstream relies on
# column positions.
train_df = pd.read_csv(TrainParams.TRAIN_CSV_PATH)

# CLIP processor; later cells hand it (and its `.tokenizer`) to the datasets.
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [4]:
# Inspect the raw annotation table; the notebook renders a truncated view
# (first and last rows) with the columns image, mask, label and category_id.
train_df

Unnamed: 0,image,mask,label,category_id
0,000000558840.jpg,000000558840_0.jpg,hot dog,58
1,000000558840.jpg,000000558840_1.jpg,bottle,44
2,000000558840.jpg,000000558840_2.jpg,cup,47
3,000000558840.jpg,000000558840_3.jpg,person,1
4,000000558840.jpg,000000558840_4.jpg,spoon,50
...,...,...,...,...
973173,000000581929.jpg,000000581929_973173.jpg,bush,97
973174,000000581929.jpg,000000581929_973174.jpg,cage,99
973175,000000581929.jpg,000000581929_973175.jpg,clouds,106
973176,000000581929.jpg,000000581929_973176.jpg,grass,124


In [5]:
# Rebalance the label distribution before splitting.
# Judging by the frequencies printed at the end of this cell (min 121,
# max 242), `ratio=2` appears to cap the most frequent label at twice the
# rarest one -- TODO confirm against balance_dataset.
# NOTE(review): the RNG seeds are only set in a later cell; if balance_dataset
# samples randomly, this step may not be reproducible -- verify.
balanced_train_df = balance_dataset(train_df, ratio=2)

# Alternative split, currently disabled:
# inductive_dataset_train = inductive_dataset(balanced_train_df, TrainParams.UNSEEN_CLASSES)
# inductive_dataset_val = inductive_dataset(balanced_train_df, TrainParams.SEEN_CLASSES)
# print(len(inductive_dataset_train), len(inductive_dataset_val))

# Active split: the training frame is built from the SEEN classes and the
# validation frame from the UNSEEN classes.
transductive_dataset_train = transductive_dataset(
    balanced_train_df,
    TrainParams.SEEN_CLASSES,
)
transductive_dataset_val = transductive_dataset(
    balanced_train_df,
    TrainParams.UNSEEN_CLASSES,
)
print(*map(len, (transductive_dataset_train, transductive_dataset_val)))

# Smallest, largest and mean number of rows per label after balancing.
label_freqs = balanced_train_df["label"].value_counts()
print(
    label_freqs.min(),
    label_freqs.max(),
    label_freqs.mean(),
)

121 64115
35012 3388
121 242 240.83625730994152


In [6]:
# Single source of truth for every random number generator used in the run.
SEED = 42

# seed_everything seeds PyTorch and NumPy (as the previous explicit calls did)
# and additionally Python's `random` module. `workers=True` asks Lightning to
# derive a distinct, reproducible seed for every DataLoader worker process, so
# any random work done inside the dataset is repeatable too.
# NOTE(review): this cell runs after the dataset balancing/splitting cell; if
# those helpers sample randomly, the seeding should be moved above them.
pl.seed_everything(SEED, workers=True)

In [7]:
# Show which image and mask directories the datasets will read from.
print(
    TrainParams.DATASET_IMAGE_FOLDER_TRAIN,
    TrainParams.DATASET_MASK_FOLDER_TRAIN,
)

ProcessedDatasetStuff512/images/train/ ProcessedDatasetStuff512/masks/train/


In [8]:
# Constructor arguments that are identical for the training and validation
# datasets. Both splits are carved out of the same balanced training frame,
# so they read from the same (train) image and mask folders.
shared_dataset_kwargs = dict(
    image_folder=TrainParams.DATASET_IMAGE_FOLDER_TRAIN,
    mask_folder=TrainParams.DATASET_MASK_FOLDER_TRAIN,
    image_size=TrainParams.IMAGE_DIM,
    mask_size=TrainParams.MASK_SIZE,
    templates=TrainParams.TEMPLATES,
    unseen_classes=TrainParams.UNSEEN_CLASSES,
    image_processor=clip_processor,
    tokenizer=clip_processor.tokenizer,
)

# Training data: rows of the seen-class split; no extra filtering is applied.
train_dataset = ZeroShotDataset(
    df=transductive_dataset_train,  # alternative: inductive_dataset_train
    filter_unseen=False,
    filter_seen=False,  # alternative: True
    **shared_dataset_kwargs,
)

# Validation data: rows of the unseen-class split; no extra filtering is applied.
val_dataset = ZeroShotDataset(
    df=transductive_dataset_val,  # alternative: inductive_dataset_val
    filter_unseen=False,  # alternative: True
    filter_seen=False,
    **shared_dataset_kwargs,
)

In [9]:
# Echo the DataLoader settings taken from the shared training configuration.
print(f"Batch size: {TrainParams.BATCH_SIZE}")
print(f"Num workers: {TrainParams.NUM_WORKERS}")

Batch size: 8
Num workers: 2


In [10]:
# Report how many samples each dataset exposes.
print("Number of training images:", len(train_dataset))
print("Number of val images:", len(val_dataset))

Number of training images: 35012
Number of val images: 3388


In [11]:
# Settings shared by both loaders.
loader_kwargs = dict(
    batch_size=TrainParams.BATCH_SIZE,
    num_workers=TrainParams.NUM_WORKERS,
    # Keep the worker processes alive between epochs instead of re-spawning
    # them at every epoch boundary (each spawn has to re-import the libraries
    # and re-pickle the dataset). PyTorch only accepts this flag when
    # num_workers > 0, hence the guard.
    persistent_workers=TrainParams.NUM_WORKERS > 0,
)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
# Each loader uses the collate function provided by its own dataset.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    **loader_kwargs,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    collate_fn=val_dataset.collate_fn,
    **loader_kwargs,
)

# NOTE: despite its name, this is the model that gets trained by trainer.fit.
test_model = CLIPConditionedSegFormer()

In [12]:
# Logs the learning rate after every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Keep the three checkpoints with the highest 'val_iou' plus the most recent
# one. `dirpath` is left unset, so Lightning picks the checkpoint location
# (the training log shows lightning_logs/version_N/checkpoints).
checkpoint_callback = ModelCheckpoint(
    filename='transformer-{epoch:02d}-{val_loss:.3f}-{val_iou:.2f}',
    monitor='val_iou',
    mode='max',
    save_top_k=3,
    save_last=True,
    verbose=True,
    # dirpath='checkpoints/',
)

# GPU training for at most 20 epochs with the two callbacks above.
trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=20,
    callbacks=[checkpoint_callback, lr_monitor],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Run the training loop; validation uses val_loader and feeds the
# checkpoint callback's monitored 'val_iou' metric.
trainer.fit(
    model=test_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                 | Params
---------------------------------------------------
0 | clip      | CLIPModel            | 149 M 
1 | segformer | ConditionedSegFormer | 13.4 M
2 | neloss    | NELoss               | 0     
3 | acc       | Accuracy             | 0     
4 | dice      | DiceLoss             | 0     
5 | iou       | IoULoss              | 0     
6 | f1score   | F1Score              | 0     
---------------------------------------------------
13.4 M    Trainable params
149 M     Non-trainable params
163 M     Total params
652.183   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 4377: 'val_iou' reached 0.11182 (best 0.11182), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=00-val_loss=0.472-val_iou=0.11.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 8754: 'val_iou' reached 0.14373 (best 0.14373), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=01-val_loss=0.455-val_iou=0.14.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 13131: 'val_iou' reached 0.12440 (best 0.14373), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=02-val_loss=0.483-val_iou=0.12.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 17508: 'val_iou' reached 0.16942 (best 0.16942), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=03-val_loss=0.440-val_iou=0.17.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 21885: 'val_iou' was not in top 3


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 26262: 'val_iou' reached 0.13865 (best 0.16942), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=05-val_loss=0.477-val_iou=0.14.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 30639: 'val_iou' reached 0.15330 (best 0.16942), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=06-val_loss=0.467-val_iou=0.15.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 7, global step 35016: 'val_iou' reached 0.15085 (best 0.16942), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=07-val_loss=0.482-val_iou=0.15.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 8, global step 39393: 'val_iou' was not in top 3


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 43770: 'val_iou' reached 0.15871 (best 0.16942), saving model to 'c:\\Users\\david\\OneDrive\\Documents\\GitHub\\TextualSegFormer\\lightning_logs\\version_27\\checkpoints\\transformer-epoch=09-val_loss=0.489-val_iou=0.16.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 48147: 'val_iou' was not in top 3
