In [None]:
import warnings
warnings.filterwarnings('ignore')

import torch
import pandas as pd
import numpy as np
import torchvision
import cv2
import pytorch_lightning as pl
import torchmetrics as tm

from torch.utils.data import Dataset, DataLoader
from ZeroShotDataset import ZeroShotDataset
from params import *
from transformers import CLIPProcessor, CLIPModel
from LossFunc import *
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from torch.utils.data import random_split
from CLIPConditionedSegFormerModel import CLIPConditionedSegFormer

In [None]:
# Evaluation scenarios, keyed by display name. Each entry holds the two filter
# flags that `test` forwards to ZeroShotDataset for that scenario.
tests_params = {
    'Unseen': dict(filter_unseen=True, filter_seen=False),
    'Seen': dict(filter_unseen=False, filter_seen=True),
    'All': dict(filter_unseen=False, filter_seen=False),
}

In [None]:
def test(model, params, dataset_params):
    dataset_params['filter_unseen'] = params['filter_unseen']
    dataset_params['filter_seen'] = params['filter_seen']
    ds = ZeroShotDataset(**dataset_params)

    loader = DataLoader(ds, batch_size=TrainParams.BATCH_SIZE, num_workers=1, shuffle=True, collate_fn=ds.collate_fn)
    trainer = pl.Trainer(accelerator='gpu', max_epochs=1)

    results = trainer.test(model, dataloaders=loader)
    return results

In [None]:
def test_model(model, dataset_params, tests_params):
    df = pd.DataFrame(columns=['test', 'acc', 'dice', 'miou', 'f1'])
    for test_name, params in tests_params.items():
        results = test(model, params, dataset_params)
        df = df.append({
            'test' : test_name,
            'acc' : results[0]['test_acc'],
            'dice' : results[0]['test_dice'],
            'miou' : results[0]['test_iou'],
            'f1' : results[0]['test_f1'],
        }, ignore_index=True)

    return df

In [None]:
# CLIP processor; reused below as both the image processor and (via its
# `.tokenizer` attribute) the text tokenizer of the dataset.
clip_processor = CLIPProcessor.from_pretrained(
    'openai/clip-vit-base-patch16'
)

In [None]:
# Restore the weights stored under the "state_dict" key of the checkpoint file.
checkpoint_path = "lightning_logs/version_13/checkpoints/transformer-epoch=03-val_loss=0.438-val_iou=0.13.ckpt"

model = CLIPConditionedSegFormer()
# map_location="cpu" keeps loading independent of the device the checkpoint
# was saved from (no CUDA needed, no extra GPU copy of the weights).
model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu")["state_dict"]
)
model = model.eval()

In [None]:
# Validation split index; handed to ZeroShotDataset through `dataset_params['df']`.
df = pd.read_csv(
    filepath_or_buffer='ProcessedDatasetStuff512/csv/val.csv',
)

In [None]:
# Keyword arguments shared by every ZeroShotDataset built in this notebook.
# The per-scenario filter flags are added on top of these later.
dataset_params = dict(
    df=df,
    image_folder=TrainParams.DATASET_IMAGE_FOLDER_VAL,
    mask_folder=TrainParams.DATASET_MASK_FOLDER_VAL,
    image_size=TrainParams.IMAGE_DIM,
    mask_size=TrainParams.MASK_SIZE,
    templates=['{}'],
    unseen_classes=TrainParams.UNSEEN_CLASSES,
    image_processor=clip_processor,
    tokenizer=clip_processor.tokenizer,
)

In [18]:
# Run every scenario in `tests_params` and collect the metrics table.
test_df = test_model(
    model=model,
    dataset_params=dataset_params,
    tests_params=tests_params,
)

In [None]:
# Separate parameter dict for the qualitative check: same settings as
# `dataset_params`, with the filter flags set to filter_unseen=True / filter_seen=False.
unseen_ds_params = {
    **dataset_params,
    'filter_unseen': True,
    'filter_seen': False,
}

unseen_ds = ZeroShotDataset(**unseen_ds_params)

In [None]:
# Draw a single (shuffled) batch from `unseen_ds` for a qualitative look.
unseen_loader = DataLoader(unseen_ds, batch_size=TrainParams.BATCH_SIZE, shuffle=True, collate_fn=unseen_ds.collate_fn)
x, x_c, condition, y = next(iter(unseen_loader))
print(x.shape, x_c.shape, condition.shape, y.shape)

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    pred = torch.sigmoid(model(x, x_c, condition))

In [None]:
# `condition` comes from a batched DataLoader, so decode it one sequence per
# sample (assumes shape (batch, seq_len) — TODO confirm against collate_fn).
clip_processor.tokenizer.batch_decode(condition)

In [None]:
import matplotlib.pyplot as plt

In [None]:

# NOTE(review): `image` was not defined by any visible cell (NameError on a
# fresh kernel). Taking the first sample of the batch instead; this assumes
# `x` is (batch, channels, H, W) with values matplotlib can display — TODO confirm.
image = x[0]
plt.imshow(image.permute(1, 2, 0))

In [None]:
# NOTE(review): `mask` was not defined by any visible cell (NameError on a
# fresh kernel). Using the first ground-truth entry of the batch; this assumes
# `y` is laid out like `pred`, i.e. (batch, 1, H, W) — TODO confirm.
mask = y[0]

plt.subplot(1, 2, 1)
# set range to [0,1] for matplotlib
plt.imshow(mask[0], vmin=0, vmax=1)
plt.subplot(1, 2, 2)
# .cpu() is a no-op for CPU tensors and keeps .numpy() valid if pred is on a GPU.
plt.imshow(pred[0][0].detach().cpu().numpy(), vmin=0, vmax=1)