## Zero-Shot Segmentation Using OpenAI's CLIP

In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import clip
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F

# Pick the compute device name: "cuda" when PyTorch reports an available GPU,
# otherwise fall back to "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Load the pretrained CLIP "ViT-B/32" checkpoint onto `device`. `clip.load`
# returns the model together with its matching image preprocessing transform.
# NOTE(review): the name `model` is rebound to a Decoder instance in the
# visualization section further down, which shadows this CLIP model.
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
from encoder import Encoder
from decoder import Decoder
from dataloader import PhraseCutDataset_

In [None]:
# Instantiate the project's Encoder (defined in encoder.py) with its default
# constructor arguments. It is later called as encoder(pil_image, phrase).
encoder = Encoder()

In [None]:
# Print the encoder's string representation for inspection.
print(encoder)

In [None]:
# Validation loader: one sample per batch, iterated in dataset order.
val = DataLoader(PhraseCutDataset_('val'), batch_size=1, shuffle=False)

# Training loader: one sample per batch, reshuffled every epoch.
# `data` is left bound to the training dataset.
data = PhraseCutDataset_('train')
train = DataLoader(dataset=data, batch_size=1, shuffle=True)

### Decoder

In [12]:
# Build the trainable segmentation decoder (defined in decoder.py).
# The same hyperparameters are used again when the model is rebuilt for
# visualization, so the saved state dict matches that architecture.
decoder = Decoder(
    extract_layers=[3, 6, 9],
    mha_heads=4,
    reduce_dim=64,
    cond_layer=3,
)

In [None]:
# Bare expression: in a notebook cell this displays the decoder's repr.
decoder

### DataLoader

In [15]:
# Bare expression: in a notebook cell this displays the training DataLoader's repr.
train

In [16]:
# Loss, optimizer and learning-rate schedule used to train the decoder.
# BCEWithLogitsLoss takes raw logits, so no sigmoid is applied before it.
criterion = nn.BCEWithLogitsLoss()

# Only the decoder's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(decoder.parameters(), lr=3e-4)

# Cosine decay of the learning rate from 3e-4 down to 1e-5 over T_max steps.
# 307480 = 10 * 30748; the training loop runs 10 epochs and steps the
# scheduler once per iteration, so this is presumably the total number of
# iterations -- TODO confirm against len(train).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=307480,
    eta_min=1e-5,
)

### Training

In [None]:
# Train the decoder for `epochs` passes over `train`, one sample per step.
# Every `log_every` iterations the current loss is recorded for plotting and
# the loss plus a running soft-IoU for the epoch are printed.
loss_hist = []    # loss sampled every `log_every` iterations
iter_n = []       # global iteration index for each entry of loss_hist
epochs = 10
log_every = 3000

for epoch in range(epochs):

    # Running mean of the soft IoU over the samples actually processed in this
    # epoch. It must live outside the batch loop: resetting it on every
    # iteration would make the printed value just current_iou / (i + 1).
    accuracies_iou = 0.0
    n_seen = 0

    for i, (phrase, input_img, output_img, _) in enumerate(train):

        # Skip samples whose batched image tensor is not 4-D
        # (presumably images without a channel axis -- TODO confirm).
        if input_img.dim() != 4:
            continue

        # The encoder takes a PIL RGB image and the phrase string.
        # NOTE(review): the permute(2, 0, 1) assumes the dataset yields
        # channel-last images -- confirm against the dataset class.
        pil_img = transforms.ToPILImage()(input_img[0].permute(2, 0, 1)).convert("RGB")
        encodings = encoder(pil_img, phrase[0])

        output = decoder(encodings)
        logits = output[0][0]
        target = output_img[0]

        loss = criterion(logits, target)

        # Soft IoU on sigmoid probabilities (no thresholding). Computed under
        # no_grad and reduced to a Python float so the running mean does not
        # keep autograd history alive across iterations.
        with torch.no_grad():
            pred = torch.sigmoid(logits)
            tp = torch.sum(pred * target)
            fp = torch.sum(pred * (1. - target))
            fn = torch.sum((1. - pred) * target)
            sample_iou = (tp / (tp + fp + fn)).item()

        # Incremental mean over processed samples only, so skipped samples do
        # not dilute the average.
        n_seen += 1
        accuracies_iou += (sample_iou - accuracies_iou) / n_seen

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler is stepped per iteration, not per epoch.
        lr_scheduler.step()

        if (i + 1) % log_every == 0:
            loss_hist.append(loss.item())
            # Global iteration index derived from the loader length rather
            # than a hard-coded dataset size.
            iter_n.append(epoch * len(train) + i)
            print(f"Epoch : {epoch + 1}, Iteration : {i+1}, Loss : {loss.item()}")
            print(f"IOU Accuracy : {100*(accuracies_iou) :.5f}%")
            

In [18]:
# Persist only the decoder's weights (its state dict) to "ClipSeg.pth" in the
# current working directory.
torch.save(decoder.state_dict(), "ClipSeg.pth")

In [None]:
# Plot the loss values sampled during training against their iteration index.
# NOTE(review): `plt` is not imported anywhere in the visible cells --
# presumably `import matplotlib.pyplot as plt` is missing; confirm and add it.
plt.plot(iter_n, loss_hist)
plt.title("Loss vs Iterations")

### Model Evaluation

In [None]:
# Evaluate the decoder on the validation loader and report the mean
# pixel-wise accuracy and mean IoU of masks thresholded at 0.3.

# Run in eval mode so layers such as dropout behave deterministically, and
# restore the previous mode afterwards so later training is unaffected.
# (Assumes Decoder is an nn.Module, as its use of parameters()/state_dict()
# elsewhere in this notebook suggests.)
was_training = decoder.training
decoder.eval()

try:
    with torch.no_grad():
        pixel = []   # per-sample fraction of correctly classified pixels
        iou = []     # per-sample intersection-over-union

        for phrase, input_img, output_img, _ in val:

            # Skip samples whose batched image tensor is not 4-D
            # (presumably images without a channel axis -- TODO confirm).
            if input_img.dim() != 4:
                continue

            pil_img = transforms.ToPILImage()(input_img[0].permute(2, 0, 1)).convert("RGB")
            encodings = encoder(pil_img, phrase[0])

            output = decoder(encodings)
            target = output_img[0]

            # Binarize the predicted probabilities at 0.3.
            pred = (torch.sigmoid(output[0][0]) > 0.3).float()

            # Mean over all pixels instead of dividing by a hard-coded 224*224.
            pixel.append((pred == target).float().mean().item())

            tp = torch.sum(pred * target)
            fp = torch.sum(pred * (1. - target))
            fn = torch.sum((1. - pred) * target)

            # When both prediction and target are empty the union is zero and
            # tp / union would be NaN, poisoning the mean. An empty prediction
            # of an empty target is counted as a perfect match.
            union = tp + fp + fn
            iou.append((tp / union).item() if union > 0 else 1.0)

        if pixel:
            print(f"Pixel-by-Pixel Accuracy : {100*sum(pixel)/len(pixel) :.5f}%")
            print(f"IOU Accuracy : {100*sum(iou)/len(iou) :.5f}%")
        else:
            # Every sample was skipped; avoid dividing by zero.
            print("No validation samples were evaluated.")
finally:
    decoder.train(was_training)

### Visualization

In [None]:
# Rebuild the decoder with the same hyperparameters used for training and load
# the trained weights for inference.
# NOTE(review): this rebinds `model`, shadowing the CLIP model loaded at the
# top of the notebook.
model = Decoder(extract_layers=[3, 6, 9], mha_heads=4, cond_layer=3, reduce_dim=64)

# Load from the path the notebook actually saves to ("ClipSeg.pth"), not
# "__pycache__/ClipSeg.pth". map_location="cpu" lets a checkpoint written on a
# GPU machine load on a CPU-only one; load_state_dict then copies the values
# into the model's own parameters.
model.load_state_dict(torch.load("ClipSeg.pth", map_location="cpu"))

# Inference only: switch to eval mode and freeze every parameter.
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

In [None]:
# Segment a single validation image for the phrase "stack of gifts" and show
# both the resized input and the predicted binary mask.
image_path = "PhraseCutDataset/data/VGPhraseCut_v0/images_val/2339423.jpg"
img = Image.open(image_path).resize((224, 224))
img.show()

# Encode the image/phrase pair and decode it into mask logits.
encodings = encoder(img, "stack of gifts")
output = model(encodings)

# Convert logits to probabilities and binarize at 0.3.
probabilities = torch.sigmoid(output)
pred = (probabilities > 0.3).float()

# Render the first predicted mask as an 8-bit grayscale image.
to_pil = transforms.ToPILImage()
img = to_pil(pred[0]).convert("L")
img.show()