In [2]:
import os
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from datasets import load_from_disk
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)

Using device: cpu


In [None]:
# Notebook-wide configuration.
# RANDOM_SEED seeds every RNG below; BATCH_SIZE is used by the DataLoaders,
# EPOCHS by the LR scheduler and the training loop, IMG_SIZE by the image
# transforms and the model, PATCH_SIZE by the model's patch embedding.
RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 40
IMG_SIZE = 224
PATCH_SIZE = 16

# Seed Python's `random`, NumPy and PyTorch so that shuffling, augmentation
# and weight initialisation start from the same state on every fresh run.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    # Also seed every visible CUDA device.
    torch.cuda.manual_seed_all(RANDOM_SEED)

In [4]:
# Locations of the datasets previously saved to disk (paths are relative to
# the notebook's working directory).
TRAIN_VAL_PATH = "processed_bird_data"
TEST_PATH = "processed_bird_test_data"

# The train/val object is indexed by split name: "train" and "validation".
print("Loading train/val dataset from:", TRAIN_VAL_PATH)
full_ds = load_from_disk(TRAIN_VAL_PATH)
train_hf = full_ds["train"]
val_hf = full_ds["validation"]

print("Train size:", len(train_hf))
print("Val size:", len(val_hf))

# The test data is stored separately and is used directly, without split lookup.
print("\nLoading test dataset from:", TEST_PATH)
test_hf = load_from_disk(TEST_PATH)
print("Test size:", len(test_hf))

# Attribute matrix: axis 0 is taken as the class count and axis 1 as the
# number of attributes per class. Both sizes are reused to build the model heads.
ATTR_PATH = "data/attributes.npy"
attributes = np.load(ATTR_PATH)
NUM_CLASSES = attributes.shape[0]
NUM_ATTR = attributes.shape[1]

print("\nAttributes shape:", attributes.shape)
print("NUM_CLASSES:", NUM_CLASSES, "| NUM_ATTR:", NUM_ATTR)

Loading train/val dataset from: processed_bird_data
Train size: 3337
Val size: 589

Loading test dataset from: processed_bird_test_data
Test size: 4000

Attributes shape: (200, 312)
NUM_CLASSES: 200 | NUM_ATTR: 312


In [5]:
from torchvision.transforms import InterpolationMode

# Final two steps shared by both pipelines: convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
_tensorize_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]

# Training-only augmentations, applied in this order before the shared tail.
_train_augmentations = [
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.GaussianBlur(kernel_size=3),
]

train_transform = transforms.Compose(_train_augmentations + _tensorize_and_normalize)

# Evaluation pipeline: a plain resize to IMG_SIZE x IMG_SIZE, no randomness.
eval_transform = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE))] + _tensorize_and_normalize
)

In [6]:
class BirdTrainDataset(Dataset):
    """Labelled split yielding ``(image, attribute_vector, label)`` per sample.

    Args:
        hf_dataset: Indexable collection whose items expose the ``"image"``
            and ``"label"`` keys.
        attributes: Array indexed by label along its first axis; it is cast
            to float32 once, at construction time.
        transform: Optional callable applied to the image before returning it.
    """

    def __init__(self, hf_dataset, attributes, transform=None):
        self.transform = transform
        self.attributes = attributes.astype("float32")
        self.ds = hf_dataset

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]

        image = record["image"]
        # Only PIL images are forced to RGB; any other type is passed through as is.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        class_idx = int(record["label"])

        if self.transform is not None:
            image = self.transform(image)

        # The label doubles as the row index into the attribute table
        # (assumes labels are 0-based -- TODO confirm against the dataset).
        class_attributes = torch.from_numpy(self.attributes[class_idx])
        return image, class_attributes, class_idx


class BirdTestDataset(Dataset):
    """Unlabelled split yielding ``(image, image_id)`` per sample.

    Args:
        hf_dataset: Indexable collection whose items expose the ``"image"``
            and ``"id"`` keys.
        transform: Optional callable applied to the image before returning it.
    """

    def __init__(self, hf_dataset, transform=None):
        self.transform = transform
        self.ds = hf_dataset

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]

        image = record["image"]
        # Only PIL images are forced to RGB; any other type is passed through as is.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        # The id is coerced to a plain Python int.
        return image, int(record["id"])


# Wrap each split: train and validation carry labels and attribute vectors,
# the test split only carries ids. Random augmentation is used for training only.
train_dataset = BirdTrainDataset(
    hf_dataset=train_hf, attributes=attributes, transform=train_transform
)
val_dataset = BirdTrainDataset(
    hf_dataset=val_hf, attributes=attributes, transform=eval_transform
)
test_dataset = BirdTestDataset(hf_dataset=test_hf, transform=eval_transform)

# Only the training loader shuffles; all loaders load in the main process
# (num_workers=0).
train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)
val_loader = DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

In [7]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Size of the embedding produced for every patch.

    Attributes:
        grid_size: Patches per image side (``img_size // patch_size``).
        num_patches: Total number of patches (``grid_size ** 2``).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=192):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size, self.patch_size = img_size, patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side * patches_per_side

        # Kernel size equal to stride: each output position sees exactly one patch.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # [B, C, H, W] -> [B, embed_dim, H/P, W/P]
        patch_map = self.proj(x)
        # Merge the two spatial axes, then put the patch axis before the channels:
        # [B, embed_dim, num_patches] -> [B, num_patches, embed_dim]
        return patch_map.flatten(start_dim=2).transpose(1, 2)


class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by an MLP, each behind a residual.

    LayerNorm is applied to the input of each sub-layer (before attention and
    before the MLP), not to its output.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        drop: Dropout probability used inside attention and inside the MLP.
    """

    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        mlp_width = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=drop, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        # Two-layer feed-forward network: expand, GELU, contract.
        expand = nn.Linear(embed_dim, mlp_width)
        contract = nn.Linear(mlp_width, embed_dim)
        self.mlp = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(drop),
            contract,
            nn.Dropout(drop),
        )

    def forward(self, x):
        # x: [B, N, D]
        tokens = self.norm1(x)
        # Query, key and value are all the same normalised tokens; the
        # attention-weights output is discarded.
        x = x + self.attn(tokens, tokens, tokens)[0]
        return x + self.mlp(self.norm2(x))


class SimpleViTWithAttributes(nn.Module):
    """Small Vision Transformer with two linear heads on the class token.

    One head produces class logits, the other predicts an attribute vector.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        num_classes: Output size of the classification head.
        num_attr: Output size of the attribute head.
        embed_dim: Token embedding size.
        depth: Number of encoder blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        drop: Dropout probability (after the positional embedding and inside the blocks).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=200, num_attr=312, embed_dim=192, depth=6,
                 num_heads=3, mlp_ratio=4.0, drop=0.1):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        # Sequence length is one class token plus all patch tokens.
        seq_len = self.patch_embed.num_patches + 1

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(drop)

        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(
                TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, drop)
            )
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head.
        self.head_class = nn.Linear(embed_dim, num_classes)
        # Attribute head.
        self.head_attr = nn.Linear(embed_dim, num_attr)

        # Re-initialise the learned embeddings and both head weight matrices
        # from a truncated normal; the head biases keep their default init.
        for tensor in (
            self.pos_embed,
            self.cls_token,
            self.head_class.weight,
            self.head_attr.weight,
        ):
            nn.init.trunc_normal_(tensor, std=0.02)

    def forward(self, x):
        # x: [B, C, H, W]
        batch = x.shape[0]
        patches = self.patch_embed(x)  # [B, N, D]

        # Prepend one copy of the class token per sample, then add positions.
        cls_prefix = self.cls_token.expand(batch, -1, -1)  # [B, 1, D]
        tokens = torch.cat([cls_prefix, patches], dim=1) + self.pos_embed  # [B, 1+N, D]
        tokens = self.pos_drop(tokens)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Only the (normalised) class token feeds the two heads.
        summary = self.norm(tokens)[:, 0]  # [B, D]
        return self.head_class(summary), self.head_attr(summary)


# Build the network from the notebook-level sizes, then move it to the
# selected device.
model = SimpleViTWithAttributes(
    img_size=IMG_SIZE, patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES, num_attr=NUM_ATTR,
    embed_dim=192, depth=6, num_heads=3,
    mlp_ratio=4.0, drop=0.1,
)
model = model.to(DEVICE)

# Last expression of the cell: shows the module structure.
model

SimpleViTWithAttributes(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x TransformerEncoderBlock(
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=192, out_features=768, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=768, out_features=192, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (head_class): Linear(in_features=192, out_features=200, bias=True)
  (head_attr): Linear(in_features=192, out_features=312, b

In [8]:
# Two objectives: cross-entropy on the class logits and mean-squared error on
# the predicted attribute vector.
criterion_class = nn.CrossEntropyLoss()
criterion_attr = nn.MSELoss()
# Weight of the attribute loss when it is added to the classification loss.
LAMBDA_ATTR = 0.05

# AdamW over all model parameters, with a cosine learning-rate schedule whose
# period (T_max) is the number of epochs.
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [9]:
def train_one_epoch(epoch_idx):
    """Run one optimisation pass over ``train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``criterion_class``, ``criterion_attr``, ``LAMBDA_ATTR`` and ``DEVICE``.

    Args:
        epoch_idx: Epoch number; used only in the progress messages.

    Returns:
        Tuple ``(avg_loss, avg_cls, avg_attr, acc)``: per-sample averages of
        the combined, classification and attribute losses, and the fraction
        of correctly classified samples.
    """
    model.train()

    sum_loss = 0.0
    sum_cls = 0.0
    sum_attr = 0.0
    n_correct = 0
    n_seen = 0

    for step, (imgs, attr_targets, labels) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        attr_targets = attr_targets.to(DEVICE)
        n_batch = imgs.size(0)

        optimizer.zero_grad()

        logits, attr_pred = model(imgs)
        cls_term = criterion_class(logits, labels)
        attr_term = criterion_attr(attr_pred, attr_targets)
        batch_loss = cls_term + LAMBDA_ATTR * attr_term

        batch_loss.backward()
        optimizer.step()

        # The criteria return batch means, so weight each by the batch size to
        # get correct per-sample averages when the last batch is smaller.
        sum_loss += batch_loss.item() * n_batch
        sum_cls += cls_term.item() * n_batch
        sum_attr += attr_term.item() * n_batch

        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += n_batch

        # Progress line every 20th batch (including the first).
        if step % 20 == 0:
            print(f"[Epoch {epoch_idx}] Batch {step}/{len(train_loader)} loss={batch_loss.item():.4f}")

    return (
        sum_loss / n_seen,
        sum_cls / n_seen,
        sum_attr / n_seen,
        n_correct / n_seen,
    )


def evaluate():
    """Score the model on ``val_loader`` without updating any weights.

    Relies on the notebook-level ``model``, ``val_loader``,
    ``criterion_class``, ``criterion_attr``, ``LAMBDA_ATTR`` and ``DEVICE``.

    Returns:
        Tuple ``(avg_loss, avg_cls, avg_attr, acc)``: per-sample averages of
        the combined, classification and attribute losses, and the fraction
        of correctly classified samples.
    """
    model.eval()

    sum_loss = 0.0
    sum_cls = 0.0
    sum_attr = 0.0
    n_correct = 0
    n_seen = 0

    # Gradients are not needed for scoring.
    with torch.no_grad():
        for imgs, attr_targets, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            attr_targets = attr_targets.to(DEVICE)
            n_batch = imgs.size(0)

            logits, attr_pred = model(imgs)
            cls_term = criterion_class(logits, labels)
            attr_term = criterion_attr(attr_pred, attr_targets)
            batch_loss = cls_term + LAMBDA_ATTR * attr_term

            # Weight the batch means by batch size so uneven batches average correctly.
            sum_loss += batch_loss.item() * n_batch
            sum_cls += cls_term.item() * n_batch
            sum_attr += attr_term.item() * n_batch

            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += n_batch

    return (
        sum_loss / n_seen,
        sum_cls / n_seen,
        sum_attr / n_seen,
        n_correct / n_seen,
    )

In [None]:
# Train for EPOCHS epochs. After every epoch the model is scored on the
# validation split, and its weights are checkpointed whenever validation
# accuracy beats the best value seen so far.
best_val_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")

    train_loss, train_cls, train_attr, train_acc = train_one_epoch(epoch)
    val_loss, val_cls, val_attr, val_acc = evaluate()

    # Advance the learning-rate schedule once per epoch, after that epoch's
    # optimizer steps.
    scheduler.step()

    # Adjacent f-strings are concatenated with nothing in between, so the
    # explicit "\n" is what keeps the Train and Val summaries on separate
    # lines (without it the output reads "acc=0.1234Val:   loss=...").
    print(
        f"Train:   loss={train_loss:.4f} (cls={train_cls:.4f}, attr={train_attr:.4f}), acc={train_acc:.4f}\n"
        f"Val:   loss={val_loss:.4f} (cls={val_cls:.4f}, attr={val_attr:.4f}), acc={val_acc:.4f}"
    )

    # Keep only the weights of the best-scoring epoch; the file is overwritten
    # on every improvement.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "vit_best_model.pth")
        print("Best VIT model saved")


Epoch 1/25
[Epoch 1] Batch 0/105 loss=5.4359
[Epoch 1] Batch 20/105 loss=5.3092
[Epoch 1] Batch 40/105 loss=5.1971


In [1]:
# TODO: run the trained model on the test data (test_loader).