## Download and set up the project and data

In [None]:
# Quietly install the CLIP package (presumably needed by models.clipseg, imported below)
# and gdown, which downloads the dataset from Google Drive in the next cells.
!pip -q install openai-clip
!pip -q install gdown

In [None]:
# Fetch the project code (presumably providing models.clipseg and config/prompts.yaml,
# both used below) and make it the working directory so that the relative paths used
# later ('./data', 'config/prompts.yaml', 'trained_model.pth') resolve inside the repo.
!git clone 'https://github.com/ararchieves/house.git'
%cd house

In [None]:
# Download the zipped dataset from Google Drive as data_cat.zip.
# fuzzy=True lets gdown extract the file id from a share-link URL
# instead of requiring a direct-download link.
import gdown
data_link = 'https://drive.google.com/file/d/1-zD25r3PPt4fDJBPKbFy6Qdemeo039lG/view?usp=sharing' 
gdown.download(url=data_link, output='data_cat.zip', quiet=False, fuzzy=True)

print("Unzipping data")
# Extract into the repo root; unzip falls back to 'data_cat.zip' when 'data_cat' is not found.
# The extracted folder is then renamed to 'data' (the ROOT_DIR used later) and the archive deleted.
# NOTE(review): assumes the archive contains a top-level `data_cat/` folder — confirm.
!unzip -q 'data_cat' -d '.'
!mv 'data_cat' 'data'
!rm 'data_cat.zip'
print("Data Unzipping Complete!")

## Code

## Imports

In [None]:
import os
import yaml
import random

from models.clipseg import CLIPDensePredT

import torch
import torchvision
from torch import nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image, ImageReadMode

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

## Config Variables

In [None]:
# Data
ROOT_DIR = './data'
BATCH_SIZE = 12
TRANSFORMS = None
SHUFFLE = True
SEED = 42
NUM_WORKERS = 2
# Training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pinned host memory only speeds up host->GPU copies. On a CPU-only machine the
# DataLoader warns and ignores it, so enable it only when CUDA is actually used.
PIN_MEMORY = DEVICE == 'cuda'
LEARNING_RATE = 1e-3
EPOCHS = 50

# Seed the global RNGs too. Otherwise SEED only reaches the DataLoader shuffle
# generators, and prompt sampling (`random.choice`) plus any torch-RNG consumers
# (e.g. model initialisation) would differ between runs.
random.seed(SEED)
torch.manual_seed(SEED)


## Data Loading and Visualization

In [None]:
class ChineseCityDataset(Dataset):
    """City-view images paired with four building-height segmentation masks.

    Directory layout read by this class::

        {root_dir}/{split}/images/{image_id}_{view}.<ext>
        {root_dir}/{split}/masks/{image_id}/{0,1,2,3}.png

    Each item is ``(img, masks)``:

    - ``img`` is the RGB image divided by 255.
    - ``masks`` is a list of four grayscale masks divided by 255.

    Mask index by building height, per the annotation notes:

    - 0: ground floor (0 floors)
    - 1: short building or house (1-4 floors)
    - 2: medium building (5-10 floors)
    - 3: tall building (11+ floors)

    Args:
        root_dir: Directory containing the ``train``/``test``/``val`` splits.
        split: One of ``'train'``, ``'test'``, ``'val'``.
        transform: Optional callable applied to the image and to each mask.

    Raises:
        ValueError: If ``split`` is not a valid split name.
    """

    SPLITS = ('train', 'test', 'val')
    NUM_MASKS = 4

    def __init__(self, root_dir='./data', split='train', transform=None):
        super().__init__()

        if split not in self.SPLITS:
            # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
            raise ValueError(f"Invalid split parameter! '{split}' not in ['train', 'test', 'val']")

        self.base_dir = f'{root_dir}/{split}'
        self.image_dir = f'{self.base_dir}/images'
        self.masks_dir = f'{self.base_dir}/masks'

        # os.listdir order is filesystem-dependent. Sorting makes index -> sample stable,
        # so the seeded DataLoader shuffle is reproducible across machines.
        # Hidden entries (e.g. .ipynb_checkpoints, .DS_Store) are not samples.
        self.images = self._list_visible(self.image_dir)
        self.mask_dirs = self._list_visible(self.masks_dir)

        self.transform = transform

    @staticmethod
    def _list_visible(directory):
        """Return the sorted names in ``directory`` that do not start with '.'."""
        return sorted(name for name in os.listdir(directory) if not name.startswith('.'))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_name = self.images[idx]

        img = read_image(f'{self.image_dir}/{image_name}', mode=ImageReadMode.RGB) / 255

        # Image names are "{image_id}_{view}"; every view of an image shares one mask folder.
        masks_dir_id = image_name.split('_')[0]

        masks = [
            read_image(f'{self.masks_dir}/{masks_dir_id}/{mask_number}.png', mode=ImageReadMode.GRAY) / 255
            for mask_number in range(self.NUM_MASKS)
        ]

        if self.transform:
            # NOTE: the transform is called separately on the image and on each mask.
            # A *random* transform (flip, crop, ...) would draw different parameters
            # per call and misalign image and masks, so use deterministic transforms
            # here or apply random augmentation jointly.
            img = self.transform(img)
            masks = [self.transform(mask) for mask in masks]

        return img, masks

In [None]:
# Training split. The seeded generator fixes the shuffle order for a given dataset order.
train_data = ChineseCityDataset(root_dir=ROOT_DIR, split='train')
trainloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    generator=torch.Generator().manual_seed(SEED),
)
print(f"Length of trainloader is: {len(trainloader)}")

In [None]:
# Validation split. Shuffling gives nothing for evaluation, and a fixed order means
# each epoch's validation batches are identical.
val_data = ChineseCityDataset(root_dir=ROOT_DIR, split='val')
valloader = DataLoader(val_data,
                       batch_size=BATCH_SIZE,
                       shuffle=False,
                       num_workers=NUM_WORKERS,
                       pin_memory=PIN_MEMORY
                       )
print(f"Length of valloader is: {len(valloader)}")

In [None]:
# One batch for visual inspection; the masks come out of the loader as four batched tensors.
images, (mask0, mask1, mask2, mask3) = next(iter(trainloader))

In [None]:
# Show up to five samples per row: the RGB view followed by its four height-class masks.
column_titles = ["View", "Ground Floor", "Short Buildings", "Medium Buildings", "Tall Buildings"]
# The batch can hold fewer than five images (e.g. a small BATCH_SIZE); never index past it.
n_rows = min(5, images.shape[0])
# squeeze=False keeps `axes` 2-D even when only one row is drawn.
fig, axes = plt.subplots(n_rows, len(column_titles), figsize=(20, 4 * n_rows), squeeze=False)

for col, title in enumerate(column_titles):
    axes[0, col].set_title(title)

for row in range(n_rows):
    axes[row, 0].imshow(images[row].permute(1, 2, 0))
    for col, mask in enumerate((mask0, mask1, mask2, mask3), start=1):
        # Masks are single-channel and scaled by 1/255. Pin the colour range to [0, 1]
        # so every mask uses the same scale and empty masks render black.
        axes[row, col].imshow(mask[row].squeeze(0), cmap='gray', vmin=0, vmax=1)

plt.show()

## Training Setup 

In [None]:
# Dense-prediction model from models.clipseg on a ViT-B/16 CLIP backbone, placed on DEVICE.
model = CLIPDensePredT(version='ViT-B/16', complex_trans_conv=True)
model = model.to(DEVICE)

In [None]:
# Count only parameters with requires_grad=True, i.e. those that receive gradients.
trainable_params = sum(
    tensor.numel()
    for tensor in model.parameters()
    if tensor.requires_grad
)
print(f"Total number of trainable parameters are: {trainable_params:,}")

In [None]:
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(mean((a - b) ** 2) + eps)``.

    ``eps`` is added under the square root because d/dx sqrt(x) is infinite at
    x = 0. Without it, a perfect prediction (MSE exactly 0) produces NaN gradients.

    Args:
        eps: Small non-negative constant added to the MSE before the square root.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, ture, pred):
        """Return the RMSE between ``ture`` and ``pred``; argument order does not matter.

        The first parameter keeps its original (misspelled) name ``ture`` so keyword
        callers keep working.
        """
        return torch.sqrt(self.mse(ture, pred) + self.eps)

In [None]:
# RMSE between predicted and target masks, minimised with Adam at LEARNING_RATE.
criterion = RMSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Text prompts per mask type, keyed "PromptsMask0" ... "PromptsMask3" (read by get_random_prompts).
# An explicit encoding avoids depending on the platform's default (e.g. cp1252 on Windows).
with open('config/prompts.yaml', 'r', encoding='utf-8') as f:
    _prompts = yaml.safe_load(f)


def get_random_prompts(mask_type, n_prompts, prompts=_prompts):
    """Draw ``n_prompts`` prompts, with replacement, describing ``mask_type``.

    Args:
        mask_type: Mask index 0-3; selects the ``PromptsMask{mask_type}`` list.
        n_prompts: Number of prompts to draw.
        prompts: Mapping loaded from the prompts YAML; defaults to the module-level config.

    Returns:
        A list of ``n_prompts`` strings, each picked independently with ``random.choice``.

    Raises:
        AssertionError: If ``mask_type`` is not one of 0, 1, 2, 3.
    """
    allowed = list(range(4))
    assert mask_type in allowed, f"Invalid mask_type! {mask_type} not in {allowed}"

    mask_keys = tuple(f"PromptsMask{k}" for k in allowed)
    # Key lookup stays inside the comprehension so it happens once per draw, as before.
    return [random.choice(prompts[mask_keys[mask_type]]) for _ in range(n_prompts)]


In [None]:
train_loss_list = []  # mean training loss per epoch
val_loss_list = []    # mean validation loss per epoch


def _expand_batch(images, masks):
    """Move a batch to DEVICE and expand it to one sample per (image, mask type).

    ``images`` is tiled once per mask type, and the masks are concatenated along the
    batch axis in the same order. Row ``k * B + i`` (B = batch size) therefore pairs
    image ``i`` with mask ``k`` and with a randomly drawn prompt for mask type ``k``.

    Returns:
        ``(tiled_images, stacked_masks, prompts)``
    """
    images = images.to(DEVICE, non_blocking=True)
    masks = [mask.to(DEVICE, non_blocking=True) for mask in masks]

    batch_size = images.shape[0]
    tiled_images = images.repeat(len(masks), 1, 1, 1)
    stacked_masks = torch.cat(masks, dim=0)
    prompts = []
    for mask_type in range(len(masks)):
        prompts += get_random_prompts(mask_type, batch_size)
    return tiled_images, stacked_masks, prompts


for epoch in range(EPOCHS):
    # ---- Training ----
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(trainloader, total=len(trainloader), desc=f'Training - Epoch {epoch+1}/{EPOCHS}: '):
        _images, _masks, prompts = _expand_batch(images, masks)

        optimizer.zero_grad()

        # NOTE(review): index [0] is assumed to be the predicted mask map. Confirm it
        # is on the same scale as the [0, 1] targets (logits vs. probabilities).
        pred_mask = model(_images, prompts)[0]
        loss = criterion(pred_mask, _masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Average over batches so train and val curves share a scale,
    # whatever the number of batches in each loader.
    epoch_train_loss = running_loss / max(len(trainloader), 1)
    train_loss_list.append(epoch_train_loss)

    # ---- Validation ----
    # Prompts are sampled randomly here too, so the validation loss is itself stochastic.
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, masks in tqdm(valloader, total=len(valloader), desc=f'Validation - Epoch {epoch+1}/{EPOCHS}: '):
            _images, _masks, prompts = _expand_batch(images, masks)
            pred_mask = model(_images, prompts)[0]
            running_loss += criterion(pred_mask, _masks).item()
    epoch_val_loss = running_loss / max(len(valloader), 1)
    # Record the validation loss (this line previously appended the training loss).
    val_loss_list.append(epoch_val_loss)

    print(f"Training Loss: {epoch_train_loss:.4f} - Validation Loss: {epoch_val_loss:.4f}")



In [None]:
# Plot both loss curves on a shared y-range spanning their combined minimum and maximum.
min_y = min(min(train_loss_list), min(val_loss_list))
max_y = max(max(train_loss_list), max(val_loss_list))
curves = ((train_loss_list, 'Training Loss', 'r'), (val_loss_list, 'Validation Loss', 'b'))
for series, curve_label, curve_color in curves:
    plt.plot(series, label=curve_label, color=curve_color)
plt.ylim(min_y, max_y)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

In [None]:
# Persist only the weights (state_dict) to trained_model.pth in the working directory.
torch.save(obj=model.state_dict(), f="trained_model.pth")