In [1]:
# Attach Google Drive to the Colab filesystem; the dataset archive is read
# from it and the best checkpoint is written back to it later on.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!unzip /content/drive/MyDrive/PHOTOCL.zip -d /content/PHOTOCL

Archive:  /content/drive/MyDrive/PHOTOCL.zip
replace /content/PHOTOCL/dataset_original/Painting/painting_00001.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [3]:
!pip install timm albumentations --quiet


In [4]:
import os
import cv2
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from sklearn.metrics import f1_score, roc_auc_score


In [5]:
DATA_ROOT = "/content/PHOTOCL/dataset_original"
CLASSES = ["Photo", "Painting", "Schematics", "Sketch", "Text"]
SEED = 42  # fixes the shuffle below so the train/val split is repeatable

# Collect one (image_path, class_name) pair per entry of every class folder.
samples = []
for cls in CLASSES:
    folder = os.path.join(DATA_ROOT, cls)
    # os.listdir returns entries in arbitrary order; sorting gives the seeded
    # shuffle a stable starting point on every run and machine.
    for fname in sorted(os.listdir(folder)):
        samples.append((os.path.join(folder, fname), cls))

# Seed Python's RNG before shuffling so the same split is produced each run.
random.seed(SEED)
random.shuffle(samples)


In [6]:
# First 80% of the shuffled sample list is used for training, the rest for
# validation.
split = int(len(samples) * 0.8)
train_samples, val_samples = samples[:split], samples[split:]


In [7]:
class PhotoCLMultiTaskDataset(Dataset):
    """Image dataset that yields one binary target per task.

    Each entry of ``samples`` is a ``(path, label)`` pair where ``label`` is a
    class name. ``__getitem__`` returns ``(image, target)`` where ``target``
    is a dict with the keys ``is_photo``, ``is_text``, ``is_art`` and
    ``is_schema``, each a float32 tensor of shape ``(1,)`` holding 0.0 or 1.0.

    Args:
        samples: sequence of ``(path, label)`` pairs.
        transform: optional callable invoked as ``transform(image=image)``
            whose result is indexed with ``["image"]``.
    """

    # Class names that count as "art" for the is_art task.
    ART_LABELS = ("Painting", "Sketch")

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        """Return the number of ``(path, label)`` pairs."""
        return len(self.samples)

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` with OpenCV as RGB; return ``None`` if it cannot be read."""
        image = cv2.imread(path)
        if image is None:
            return None
        # cv2.imread yields BGR channel order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``.

        If the image at ``idx`` cannot be read, the following samples are
        tried in order (wrapping around), so the substitute is deterministic
        and the search always terminates.

        Raises:
            IndexError: if ``idx`` is outside the dataset.
            RuntimeError: if no sample in the dataset can be read.
        """
        n = len(self.samples)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Bounded, deterministic fallback instead of unbounded random recursion.
        for offset in range(n):
            path, label = self.samples[(idx + offset) % n]
            image = self._read_rgb(path)
            if image is not None:
                break
        else:
            raise RuntimeError("no readable image found in the dataset")

        target = {
            "is_photo": torch.tensor([label == "Photo"], dtype=torch.float32),
            "is_text": torch.tensor([label == "Text"], dtype=torch.float32),
            "is_art": torch.tensor([label in self.ART_LABELS], dtype=torch.float32),
            "is_schema": torch.tensor([label == "Schematics"], dtype=torch.float32),
        }

        if self.transform:
            image = self.transform(image=image)["image"]

        return image, target


In [8]:
# Validation pipeline: fixed 300x300 resize, normalisation with the library's
# default statistics, then conversion to a tensor. No random augmentation.
val_tfms = A.Compose(
    [
        A.Resize(height=300, width=300),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Training pipeline: same resize/normalise/tensor steps with random flip,
# brightness/contrast and blur augmentations applied in between.
train_tfms = A.Compose(
    [
        A.Resize(height=300, width=300),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.3),
        A.GaussianBlur(p=0.2),
        A.Normalize(),
        ToTensorV2(),
    ]
)


In [9]:
# Wrap each split in a dataset carrying its own transform pipeline.
train_ds = PhotoCLMultiTaskDataset(samples=train_samples, transform=train_tfms)
val_ds = PhotoCLMultiTaskDataset(samples=val_samples, transform=val_tfms)

# Only the training loader shuffles; validation keeps dataset order.
train_loader = DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    val_ds,
    batch_size=16,
    shuffle=False,
    num_workers=2,
)


In [10]:
class MultiTaskVisionModel(nn.Module):
    """Shared image backbone followed by one single-logit linear head per task.

    ``forward`` returns a dict keyed by task name (``is_photo``, ``is_text``,
    ``is_art``, ``is_schema``); each value is the sigmoid of that head's
    output, i.e. a probability in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Pretrained EfficientNet-B3 created with num_classes=0; its output is
        # used directly as the feature vector for the heads.
        self.backbone = timm.create_model("efficientnet_b3", pretrained=True, num_classes=0)
        feat_dim = self.backbone.num_features

        # Heads are created in this fixed order, one Linear(feat_dim, 1) each.
        task_names = ("is_photo", "is_text", "is_art", "is_schema")
        self.heads = nn.ModuleDict(
            {name: nn.Linear(feat_dim, 1) for name in task_names}
        )

    def forward(self, x):
        """Return ``{task_name: sigmoid(head(backbone(x)))}`` for every head."""
        shared = self.backbone(x)
        probs = {}
        for name, head in self.heads.items():
            probs[name] = torch.sigmoid(head(shared))
        return probs


In [11]:
# Per-task multipliers applied to each task's binary cross-entropy term.
weights = {
    "is_photo": 1.0,
    "is_text": 1.0,
    "is_art": 0.8,
    "is_schema": 0.8,
}


def multitask_loss(preds, targets):
    """Return the weighted sum of per-task binary cross-entropy losses.

    Args:
        preds: dict mapping task name to predicted probabilities.
        targets: dict mapping task name to target tensors; each is moved to
            the device of the matching prediction before the loss is computed.

    Returns:
        The sum over the tasks in ``preds`` of ``weights[task] * BCE``;
        ``0`` when ``preds`` is empty.
    """
    bce = nn.functional.binary_cross_entropy
    return sum(
        weights[task] * bce(preds[task], targets[task].to(preds[task].device))
        for task in preds
    )


In [18]:
# Number of passes over the training set.
EPOCHS = 5

# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = MultiTaskVisionModel()
model = model.to(device)

# AdamW over every model parameter (backbone and heads) with one learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=3e-4)


In [13]:
def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return ``(mean_loss, metrics)``.

    Uses the module-level ``model``, ``optimizer``, ``device``, ``weights``
    and ``multitask_loss``.

    Args:
        loader: iterable of ``(images, targets)`` batches that supports
            ``len()``; ``targets`` is a dict of CPU tensors keyed by task.
        train: when True the model is put in training mode and the optimizer
            is stepped on every batch; when False the model is put in eval
            mode and the forward pass runs with gradient tracking disabled.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the summed batch loss
        divided by ``len(loader)`` and ``metrics[task]`` is a dict with keys
        ``"F1"`` (predictions thresholded at 0.5) and ``"AUC"``. ``"AUC"`` is
        NaN when the targets of a task contain a single class, because ROC AUC
        is undefined in that case.
    """
    if train:
        model.train()
    else:
        model.eval()
    total_loss = 0

    all_preds = {k: [] for k in weights}
    all_targets = {k: [] for k in weights}

    # Disable autograd during evaluation: no graph is built, which saves
    # memory and time. In training mode this is a no-op.
    with torch.set_grad_enabled(train):
        for images, targets in tqdm(loader):
            images = images.to(device)

            if train:
                optimizer.zero_grad()

            outputs = model(images)
            loss = multitask_loss(outputs, targets)

            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

            for k in outputs:
                all_preds[k].append(outputs[k].detach().cpu().numpy())
                all_targets[k].append(targets[k].numpy())

    metrics = {}
    for k in weights:
        # Flatten the stacked per-batch arrays to the 1-D form the metric
        # functions expect.
        y_true = np.concatenate(all_targets[k]).ravel()
        y_pred = np.concatenate(all_preds[k]).ravel()

        # roc_auc_score raises when only one class is present in y_true.
        if np.unique(y_true).size < 2:
            auc = float("nan")
        else:
            auc = roc_auc_score(y_true, y_pred)

        metrics[k] = {
            "F1": f1_score(y_true > 0.5, y_pred > 0.5),
            "AUC": auc,
        }

    return total_loss / len(loader), metrics


In [14]:
def is_valid_image(path):
    """Return True if OpenCV can read ``path`` and convert it from BGR to RGB.

    Returns False when ``cv2.imread`` yields ``None`` or when reading or
    converting raises an ordinary exception. ``KeyboardInterrupt`` and
    ``SystemExit`` are deliberately not caught, so a long scan over many
    files can still be interrupted.
    """
    try:
        img = cv2.imread(path)
        if img is None:
            return False
        # Conversion is attempted only to confirm the decoded array is usable.
        _ = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return True
    except Exception:
        return False


In [15]:
# Partition the sample list: readable images keep their (path, label) pair,
# unreadable ones are recorded by path only.
clean_samples, bad_samples = [], []

for path, label in samples:
    if not is_valid_image(path):
        bad_samples.append(path)
        continue
    clean_samples.append((path, label))

print(f"Valid images: {len(clean_samples)}")
print(f"Corrupted images removed: {len(bad_samples)}")


Valid images: 41399
Corrupted images removed: 1


In [16]:
# Re-split using only the images that passed validation.
random.shuffle(clean_samples)

split = int(0.8 * len(clean_samples))
train_samples = clean_samples[:split]
val_samples = clean_samples[split:]

# The datasets and loaders created earlier still reference the lists from the
# first, unfiltered split; rebinding train_samples/val_samples does not update
# them. Rebuild them here so training actually runs on the cleaned split.
train_ds = PhotoCLMultiTaskDataset(train_samples, train_tfms)
val_ds = PhotoCLMultiTaskDataset(val_samples, val_tfms)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2)


In [19]:
# Lowest validation loss seen so far; only an epoch that beats it is saved.
best_val_loss = float("inf")

for epoch in range(EPOCHS):
    train_loss, _ = run_epoch(train_loader, train=True)
    val_loss, val_metrics = run_epoch(val_loader, train=False)

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train loss: {train_loss:.4f}")
    print(f"Val loss: {val_loss:.4f}")

    # Written as "not (a < b)" so a NaN validation loss is never treated as
    # an improvement.
    if not (val_loss < best_val_loss):
        print("⚠️ Val loss worse than best")
        continue

    best_val_loss = val_loss
    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "val_loss": val_loss,
    }
    torch.save(checkpoint, "/content/drive/MyDrive/best_multitask_model.pth")
    print("✅ Best model saved")


100%|██████████| 2070/2070 [11:10<00:00,  3.09it/s]
100%|██████████| 518/518 [01:20<00:00,  6.47it/s]



Epoch 1/5
Train loss: 0.1953
Val loss: 0.0586
✅ Best model saved


100%|██████████| 2070/2070 [11:08<00:00,  3.10it/s]
100%|██████████| 518/518 [01:19<00:00,  6.49it/s]



Epoch 2/5
Train loss: 0.0977
Val loss: 0.0543
✅ Best model saved


100%|██████████| 2070/2070 [11:07<00:00,  3.10it/s]
100%|██████████| 518/518 [01:18<00:00,  6.59it/s]



Epoch 3/5
Train loss: 0.0757
Val loss: 0.0603
⚠️ Val loss worse than best


100%|██████████| 2070/2070 [11:08<00:00,  3.10it/s]
100%|██████████| 518/518 [01:19<00:00,  6.52it/s]



Epoch 4/5
Train loss: 0.0670
Val loss: 0.1951
⚠️ Val loss worse than best


100%|██████████| 2070/2070 [11:08<00:00,  3.10it/s]
100%|██████████| 518/518 [01:25<00:00,  6.05it/s]


Epoch 5/5
Train loss: 0.0494
Val loss: 0.0623
⚠️ Val loss worse than best



