In [8]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid

from models_v2 import *
from src.plots import *
#from src.plots_paper import *

In [9]:
import os
import glob
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

#from dinov2.models.vision_transformer import vit_large
#import dinov2.eval.segmentation.utils.colormaps as colormaps


In [30]:
# setup

#image_size = 518
image_size = 224
patch_size = 14
num_classes = 150

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# vit_ckpt = "/home/azywot/DINOv2/dinov2_vitl14_pretrain.pth"
# save_model_path = "/home/azywot/DINOv2/saved_models/seg_vitl14_no_registers.pth"
# output_seg_path = "/home/azywot/DINOv2/saved_models/ADE_val_00001112_segmented.png"
# custom_image_path = "/home/azywot/DINOv2/ADE_val_00001112.jpg"
# custom_mask_path = "/home/azywot/DINOv2/ADE_val_00001112_seg.png"
# data_root = "/home/azywot/DINOv2/ADE20K_2021_17_01/images/ADE"

data_root = "./images/ADE20K_2021_17_01/images/ADE"
save_model_path = "./first_try.pth"

In [53]:

class ADE20KSegmentation(Dataset):
    def __init__(self, root, split="training", image_size=518):
        self.image_files = sorted(glob.glob(os.path.join(root, split, "**", "*.jpg"), recursive=True))
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        mask_path = img_path.replace(".jpg", "_seg.png")

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = self.transform(image)
        mask = np.array(mask)
        mask = Image.fromarray(mask).resize((self.image_size, self.image_size), resample=Image.NEAREST)
        mask = np.array(mask).astype(np.int64)

        mask[(mask != 255) & (mask > 149)] = 149
        mask = torch.from_numpy(mask)

        return image, mask

class LinearHead(nn.Module):
    def __init__(self, in_dim=1024, num_classes=150, patch_size=14, image_size=518):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.output_size = (image_size, image_size)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_dim, in_dim, kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(in_dim, num_classes, kernel_size=1)
        )

    def forward(self, patch_tokens):
        B, N, C = patch_tokens.shape
        H = W = int(N ** 0.5)
        patch_tokens = patch_tokens.permute(0, 2, 1).reshape(B, C, H, W)
        logits = self.decoder(patch_tokens)
        logits = nn.functional.interpolate(logits, size=self.output_size, mode="bilinear", align_corners=False)
        return logits

class DeitSegModel(nn.Module):
    def __init__(self, vit, head, n_reg=0):
        super().__init__()
        self.vit = vit
        self.head = head
        self.n_reg = n_reg

    # def forward(self, x):
    #     features = self.vit.forward_features(x)
    #     return self.head(features["x_norm_patchtokens"])

    def forward(self, x):
        self.vit.forward_features(x)
        features = self.vit.block_output['final']
        if self.n_reg > 0:
            return self.head(features[:, 1 : - self.n_reg])
        else:
            return self.head(features[:, 1 :])
    
def compute_miou(preds, labels, num_classes, ignore_index=255):
    mask = labels != ignore_index
    preds = preds[mask]
    labels = labels[mask]
    cm = confusion_matrix(labels.flatten(), preds.flatten(), labels=list(range(num_classes)))
    intersection = np.diag(cm)
    union = np.sum(cm, axis=1) + np.sum(cm, axis=0) - np.diag(cm)
    iou = intersection / np.maximum(union, 1)
    return np.mean(iou), iou


In [54]:
print("Loading ViT-L/14...")
#vit = vit_large(patch_size=patch_size, img_size=image_size, init_values=1.0, block_chunks=0, num_register_tokens=0)

model_size = "small"
# we always used 224x224 images for deit
image_size = 224
side = 14
n_reg = 0

# Load the pretrained model
if model_size == "tiny":
    vit = deit_tiny_patch16_LS(
        pretrained=True, img_size=image_size, pretrained_21k=True
    )
elif model_size == "small":
    vit = deit_small_patch16_LS(
        pretrained=True, img_size=image_size, pretrained_21k=True
    )
elif model_size == "base":
    vit = deit_base_patch16_LS(
        pretrained=True, img_size=image_size, pretrained_21k=True
    )
elif model_size == "large":
    vit = deit_large_patch16_LS(
        pretrained=True, img_size=image_size, pretrained_21k=True
    )
else:
    raise ValueError("Invalid model size: choose from 'tiny', 'small', 'base', 'large'")

#vit.load_state_dict(torch.load(vit_ckpt, map_location="cpu"), strict=False)


Loading ViT-L/14...
******************** PRETRAINED 21k MODEL WILL BE USED


In [55]:

print("Initializing segmentation head...")
head = LinearHead(patch_size=patch_size, image_size=image_size)
model = DeitSegModel(vit, head).to(device)


Initializing segmentation head...


In [56]:

# Freeze all but last 2 blocks + head
for p in model.parameters():
    p.requires_grad = False
for block in model.vit.blocks[-2:]:
    for p in block.parameters():
        p.requires_grad = True
for p in model.head.parameters():
    p.requires_grad = True



In [57]:
train_dataset = ADE20KSegmentation(data_root, split="training", image_size=image_size)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)


In [60]:
print("Fine-tuning...")
model.train()
for epoch in range(5):
    total_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/5"):
        images, masks = images.to(device), masks.to(device)
        preds = model(images)
        loss = criterion(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss / len(train_loader):.4f}")

torch.save(model.state_dict(), save_model_path)
print(f"Model saved to {save_model_path}")


Fine-tuning...


Epoch 1/5:   0%|          | 0/1599 [00:01<?, ?it/s]


RuntimeError: Given transposed=1, weight of size [1024, 1024, 2, 2], expected input[16, 384, 14, 14] to have 1024 channels, but got 384 channels instead

In [None]:

# ---- Evaluation on full validation set ----
print("Evaluating on validation set...")
val_dataset = ADE20KSegmentation(data_root, split="validation", image_size=image_size)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

model.load_state_dict(torch.load(save_model_path))
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for images, masks in tqdm(val_loader, desc="Evaluating"):
        images = images.to(device)
        logits = model(images)
        preds = logits.argmax(1).cpu()
        all_preds.append(preds)
        all_labels.append(masks)

all_preds = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()

miou, iou_per_class = compute_miou(all_preds, all_labels, num_classes=num_classes)
print(f"Validation mIoU: {miou:.4f}")

# ---- Evaluation on custom test image ----
print("Evaluating custom test image...")
image = Image.open(custom_image_path).convert("RGB")
mask = Image.open(custom_mask_path).convert("L")

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
image_tensor = transform(image).unsqueeze(0).to(device)
mask_np = np.array(mask.resize((image_size, image_size), resample=Image.NEAREST)).astype(np.int64)
mask_np[(mask_np != 255) & (mask_np > 149)] = 149

with torch.no_grad():
    pred = model(image_tensor).argmax(1).squeeze(0).cpu().numpy()

colormap = np.array(colormaps.ADE20K_COLORMAP, dtype=np.uint8)
colored_pred = colormap[pred + 1]
Image.fromarray(colored_pred).save(output_seg_path)

miou_test, _ = compute_miou(pred, mask_np, num_classes=num_classes)
print(f"Test image mIoU: {miou_test:.4f}")
print(f"Segmentation saved at {output_seg_path}")
