# 🧪 Multi-Task Learning Assignment (Vision, PyTorch)

Auto-grading is enabled. Fill in the student-info fields in the first code cell. Where indicated, **do not change variable names**.

In [None]:
# @title 1) Student Info & Config
# All code comments are in English.

# === MUST FILL ===
full_name = "Doe John"        # e.g., "Toshchev Alexander"
student_group = "11-111"      # e.g., "208"
assignment_id = "HW_MTL_01"

# Refuse to continue while template placeholders are still present.
# NOTE(review): these guards compare against the Russian placeholder strings,
# not against the "Doe John" / "11-111" defaults above, so an unedited copy of
# this template still passes them — confirm which placeholders are intended.
assert full_name != "Фамилия Имя", "Заполните full_name"
assert student_group != "Группа", "Заполните student_group"
print("✔ Student Info OK")

# Reference human accuracy (%) for simple image classification; kept only as
# a narrative target for reports.
HUMAN_ACCURACY = 98.0  # @param {type:"number"}

print("Student:", full_name)
print("Human reference accuracy (%):", HUMAN_ACCURACY)

from datetime import datetime, timedelta, timezone
import os

# Submission window (example). Values are ISO 8601 strings; include a UTC
# offset to make them unambiguous (naive values are interpreted as UTC below).
start_at_iso = "2025-10-20T09:00-04:00"  # @param {type:"string"}
due_at_iso   = "2025-11-03T23:59-04:00"  # @param {type:"string"}


def ensure_aware(dt, default_tz=timezone.utc):
    """Return *dt* unchanged if it has tzinfo, otherwise attach *default_tz*.

    Comparing or subtracting a naive and an aware datetime raises TypeError,
    so every timestamp fed to the penalty computation is normalized here.
    """
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=default_tz)


start_dt = ensure_aware(datetime.fromisoformat(start_at_iso))
due_dt   = ensure_aware(datetime.fromisoformat(due_at_iso))

# For the protocol: the submission time is the mtime of this file/notebook,
# or the current time if that file cannot be stat'ed.
nb_path = __file__ if "__file__" in globals() else "MTL_Assignment.ipynb"
try:
    mtime = os.path.getmtime(nb_path)
    submission_dt = datetime.fromtimestamp(mtime, tz=timezone.utc)
except (OSError, ValueError, OverflowError):
    # datetime.utcnow() is deprecated since Python 3.12; ask for an aware
    # "now" directly instead of patching tzinfo onto a naive value.
    submission_dt = datetime.now(timezone.utc)

def penalty_fraction(start_dt, due_dt, submission_dt):
    """Return the late-submission penalty as a fraction in [0, 1].

    A submission at or before ``due_dt`` gets 0.0. After the deadline the
    penalty grows linearly with the delay and reaches 1.0 once the delay
    equals the window length ``due_dt - start_dt``; it is clamped to 1.0
    afterwards. If the window length is zero or negative, any late
    submission gets the full penalty 1.0.
    """
    if submission_dt <= due_dt:
        return 0.0
    window_s = (due_dt - start_dt).total_seconds()
    delay_s = (submission_dt - due_dt).total_seconds()
    if window_s > 0:
        return min(1.0, max(0.0, delay_s / window_s))
    # Degenerate window: there is no grace period to scale against.
    return 1.0 if delay_s > 0 else 0.0

# Report the window and the submission time converted to UTC, so the "(UTC)"
# labels hold regardless of the offsets used in the ISO strings above.
print(f"Window: {start_dt.astimezone(timezone.utc).isoformat()} — {due_dt.astimezone(timezone.utc).isoformat()} (UTC)")
print(f"Submitted at: {submission_dt.astimezone(timezone.utc).isoformat()} (UTC)")

# raw score accumulator
raw_score = 0.0
max_points = 100


In [None]:
# @title 2) Environment Check
import torch, torchvision

# Log library versions and accelerator visibility to make runs reproducible.
torch_version = torch.__version__
print("Torch:", torch_version)
print("Torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())

In [None]:
# @title 3) Setup & Utilities
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import OneCycleLR
import random, time, json, math, os

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Data-loading / training hyper-parameters.
NUM_WORKERS = 2
BATCH_SIZE = 128
EPOCHS_STAGE = 3  # quick stage; feel free to increase later

# Seed Python's and PyTorch's RNGs (best-effort reproducibility).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

print("Device:", DEVICE)

def count_trainable_params(model):
    """Return the number of scalar parameters in *model* that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

@torch.no_grad()
def evaluate_mtl(trunk, head_cls, head_rot, dl, device=None):
    """Compute accuracies of the classification and rotation heads on ``dl``.

    All three modules are switched to eval mode and remain in eval mode on
    return. ``dl`` must yield ``(x, y_cls, y_rot)`` batches; the trunk output
    is fed to both heads, and predictions are the arg-max over dim 1 of each
    head's logits.

    Args:
        trunk, head_cls, head_rot: the shared trunk and the two task heads.
        dl: iterable of ``(x, y_cls, y_rot)`` tensor triples.
        device: device the batches are moved to; ``None`` (default) means
            the notebook-global ``DEVICE``.

    Returns:
        ``(acc_cls, acc_rot)`` as floats in [0, 1].

    Raises:
        ValueError: if ``dl`` yields no samples (previously this surfaced as
            an opaque ``ZeroDivisionError``).
    """
    if device is None:
        device = DEVICE
    trunk.eval(); head_cls.eval(); head_rot.eval()
    total, correct_cls, correct_rot = 0, 0, 0
    for x, y_cls, y_rot in dl:
        x, y_cls, y_rot = x.to(device), y_cls.to(device), y_rot.to(device)
        f = trunk(x)
        pred_cls = head_cls(f).argmax(1)
        pred_rot = head_rot(f).argmax(1)
        correct_cls += (pred_cls == y_cls).sum().item()
        correct_rot += (pred_rot == y_rot).sum().item()
        total += y_cls.size(0)
    if total == 0:
        raise ValueError("evaluate_mtl: data loader yielded no samples")
    return correct_cls / total, correct_rot / total

def train_epoch_mtl(trunk, head_cls, head_rot, dl, opt, sched, criterion, weights=None, uncertainty=None):
    """Train the shared trunk and both task heads for one pass over ``dl``.

    ``dl`` yields ``(x, y_cls, y_rot)`` batches, which are moved to the global
    ``DEVICE``. Trunk features feed both heads and ``criterion`` is applied to
    each head's logits. The two task losses are combined as follows:

    * ``uncertainty is None``: ``w_cls * L_cls + w_rot * L_rot`` with
      ``(w_cls, w_rot) = weights`` (default ``(1.0, 1.0)``).
    * otherwise (uncertainty weighting, Kendall & Gal style), with ``s =
      uncertainty`` holding two log-variances:
      ``exp(-s[0]) * L_cls + s[0] + exp(-s[1]) * L_rot + s[1]``.
      ``weights`` is ignored in this mode.

    Gradients of the trunk and head parameters are clipped to a total norm of
    5.0. Gradients of ``uncertainty`` are not included in that clip. Then
    ``opt`` steps, and ``sched`` (if given) steps once per batch.
    """
    for module in (trunk, head_cls, head_rot):
        module.train()
    w_cls, w_rot = (1.0, 1.0) if weights is None else weights
    # The parameter objects stay the same for the whole epoch, so the list
    # handed to the gradient clipper can be built once up front.
    clip_params = [*trunk.parameters(), *head_cls.parameters(), *head_rot.parameters()]

    for x, y_cls, y_rot in dl:
        x = x.to(DEVICE)
        y_cls = y_cls.to(DEVICE)
        y_rot = y_rot.to(DEVICE)
        opt.zero_grad()
        feats = trunk(x)
        logits_cls, logits_rot = head_cls(feats), head_rot(feats)
        loss_cls = criterion(logits_cls, y_cls)
        loss_rot = criterion(logits_rot, y_rot)
        if uncertainty is not None:
            s = uncertainty
            loss = (torch.exp(-s[0]) * loss_cls + s[0]
                    + torch.exp(-s[1]) * loss_rot + s[1])
        else:
            loss = w_cls * loss_cls + w_rot * loss_rot
        loss.backward()
        nn.utils.clip_grad_norm_(clip_params, max_norm=5.0)
        opt.step()
        if sched is not None:
            sched.step()

In [None]:
# @title 4) Data: CIFAR10 + Rotation Task
# Multi-task setup: Task A = CIFAR10 classification (10 classes)
# Task B = Rotation prediction with 4 bins {0°, 90°, 180°, 270°}

from torchvision import transforms, datasets
import torch

IMG_SIZE = 224

# The trunk uses ImageNet-pretrained ResNet18 weights, which expect inputs
# normalized with the ImageNet channel statistics. Without this step the
# pretrained features receive inputs from a different distribution.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

base_train_tfms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

base_val_tfms = transforms.Compose([
    transforms.Resize(IMG_SIZE + 32),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_base = datasets.CIFAR10(root='./data', train=True, download=True, transform=base_train_tfms)
val_base   = datasets.CIFAR10(root='./data', train=False, download=True, transform=base_val_tfms)

def rotate_batch(x, y_cls):
    """Rotate each image of a batch by a random multiple of 90 degrees.

    Args:
        x: batch tensor; each sample ``x[i]`` is rotated with
            ``torch.rot90(..., dims=(1, 2))`` (assumes CHW samples, i.e. a
            ``(B, C, H, W)`` batch — TODO confirm for other layouts).
        y_cls: class labels, returned unchanged.

    Returns:
        ``(x_rotated, y_cls, rots)`` where ``rots`` is a ``(B,)`` integer
        tensor with values in {0, 1, 2, 3}: the number of 90° turns applied.
    """
    k_turns = torch.randint(0, 4, (x.size(0),))  # one rotation label per image
    rotated = [torch.rot90(img, int(k), dims=(1, 2)) for img, k in zip(x, k_turns)]
    return torch.stack(rotated, dim=0), y_cls, k_turns

class RotWrapper(Dataset):
    """Pass-through wrapper around a dataset of ``(img, label)`` pairs.

    No rotation happens here. Rotations are added at batch level by the
    DataLoader collate function (see ``collate_with_rotate``).
    """

    def __init__(self, base_ds):
        self.base = base_ds

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # Unpacking enforces that the base dataset yields exactly two items.
        img, label = self.base[idx]
        return (img, label)

def collate_with_rotate(batch):
    """Collate ``(img, label)`` samples and attach random rotation labels.

    The images are stacked into one tensor and the labels are converted to
    an int64 tensor. Both are then passed to ``rotate_batch``.

    Returns:
        ``(x_rotated, y_cls, y_rot)`` as produced by ``rotate_batch``.
    """
    images, labels = zip(*batch)
    x_batch = torch.stack(images, dim=0)
    y_batch = torch.tensor(labels, dtype=torch.long)
    return rotate_batch(x_batch, y_batch)

from torch.utils.data import DataLoader, Dataset


class FixedRotWrapper(Dataset):
    """Evaluation dataset that rotates each sample by a fixed amount.

    ``collate_with_rotate`` draws fresh random rotations on every pass. Used
    for validation, that made the rotation labels (and thus ``acc_rot``)
    change between epochs and between tasks, which adds noise to the
    cross-task comparisons used for grading. Here sample ``idx`` is always
    rotated by ``idx % 4`` quarter turns (``torch.rot90`` over dims (1, 2)).
    The result is a balanced and reproducible rotation split.

    Items are ``(rotated_img, label, rot)``. The default DataLoader collate
    turns the ints into int64 tensors, matching ``collate_with_rotate``.
    """

    def __init__(self, base_ds):
        self.base = base_ds

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, y = self.base[idx]
        rot = idx % 4
        return torch.rot90(img, rot, dims=(1, 2)), y, rot


train_ds = RotWrapper(train_base)
val_ds   = FixedRotWrapper(val_base)

BATCH_SIZE = 128
NUM_WORKERS = 2
# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
PIN_MEMORY = torch.cuda.is_available()
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_with_rotate)
val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

num_classes = 10
num_rot_bins = 4
print("Train/Val sizes:", len(train_ds), len(val_ds))

In [None]:
# @title 5) Model: Shared Trunk + Two Heads
import torch.nn as nn
from torchvision import models

class Trunk(nn.Module):
    """Shared feature extractor: pretrained ResNet18 without its final ``fc``.

    Every ResNet18 child module except the last (the classifier) is kept, so
    the output is the pooled ``[B, 512, 1, 1]`` map. It is flattened to
    ``[B, out_dim]``, where ``out_dim`` is the input width of the dropped
    ``fc`` layer.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        kept_layers = list(backbone.children())[:-1]  # drop the classifier head
        self.features = nn.Sequential(*kept_layers)
        self.out_dim = backbone.fc.in_features

    def forward(self, x):
        pooled = self.features(x)
        return pooled.view(pooled.size(0), -1)

class HeadCls(nn.Module):
    """Linear classification head: ``in_dim`` features -> ``ncls`` logits."""

    def __init__(self, in_dim, ncls):
        super().__init__()
        self.fc = nn.Linear(in_dim, ncls)

    def forward(self, f):
        logits = self.fc(f)
        return logits

class HeadRot(nn.Module):
    """Linear rotation head: ``in_dim`` features -> ``nbins`` rotation logits."""

    def __init__(self, in_dim, nbins):
        super().__init__()
        self.fc = nn.Linear(in_dim, nbins)

    def forward(self, f):
        logits = self.fc(f)
        return logits

# Shared trunk plus one linear head per task, all placed on DEVICE.
trunk = Trunk().to(DEVICE)
head_cls, head_rot = (
    HeadCls(trunk.out_dim, num_classes).to(DEVICE),
    HeadRot(trunk.out_dim, num_rot_bins).to(DEVICE),
)

def count_trainable_params(model):
    """Return how many scalar parameters of *model* have ``requires_grad=True``."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(p.numel() for p in trainable)

print("Trainable params (total):", sum(map(count_trainable_params, (trunk, head_cls, head_rot))))

## 6) Task 1 — **Hard Sharing, Equal Weights** (max 25 pts)
Train the shared trunk with both task heads (CIFAR-10 classification and rotation prediction), using equal loss weights (1.0, 1.0).

In [None]:
# @title Run Task 1 (Equal Weights)
import torch.nn as nn, torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

criterion = nn.CrossEntropyLoss()
# A single AdamW over the trunk and both heads. OneCycleLR is stepped per
# batch, so the schedule spans EPOCHS_STAGE * len(train_dl) steps.
opt = optim.AdamW(
    [*trunk.parameters(), *head_cls.parameters(), *head_rot.parameters()],
    lr=1e-3,
    weight_decay=1e-4,
)
sched = OneCycleLR(opt, max_lr=1e-3, epochs=EPOCHS_STAGE,
                   steps_per_epoch=max(1, len(train_dl)))

@torch.no_grad()
def evaluate_mtl(trunk, head_cls, head_rot, dl, device=None):
    """Compute accuracies of the classification and rotation heads on ``dl``.

    All three modules are switched to eval mode and remain in eval mode on
    return. ``dl`` must yield ``(x, y_cls, y_rot)`` batches; the trunk output
    is fed to both heads, and predictions are the arg-max over dim 1 of each
    head's logits.

    Args:
        trunk, head_cls, head_rot: the shared trunk and the two task heads.
        dl: iterable of ``(x, y_cls, y_rot)`` tensor triples.
        device: device the batches are moved to; ``None`` (default) means
            the notebook-global ``DEVICE``.

    Returns:
        ``(acc_cls, acc_rot)`` as floats in [0, 1].

    Raises:
        ValueError: if ``dl`` yields no samples (previously this surfaced as
            an opaque ``ZeroDivisionError``).
    """
    if device is None:
        device = DEVICE
    trunk.eval(); head_cls.eval(); head_rot.eval()
    total, correct_cls, correct_rot = 0, 0, 0
    for x, y_cls, y_rot in dl:
        x, y_cls, y_rot = x.to(device), y_cls.to(device), y_rot.to(device)
        f = trunk(x)
        pred_cls = head_cls(f).argmax(1)
        pred_rot = head_rot(f).argmax(1)
        correct_cls += (pred_cls == y_cls).sum().item()
        correct_rot += (pred_rot == y_rot).sum().item()
        total += y_cls.size(0)
    if total == 0:
        raise ValueError("evaluate_mtl: data loader yielded no samples")
    return correct_cls / total, correct_rot / total

def train_epoch_mtl(trunk, head_cls, head_rot, dl, opt, sched, criterion, weights=None, uncertainty=None):
    """Train the shared trunk and both task heads for one pass over ``dl``.

    ``dl`` yields ``(x, y_cls, y_rot)`` batches, which are moved to the global
    ``DEVICE``. Trunk features feed both heads and ``criterion`` is applied to
    each head's logits. The two task losses are combined as follows:

    * ``uncertainty is None``: ``w_cls * L_cls + w_rot * L_rot`` with
      ``(w_cls, w_rot) = weights`` (default ``(1.0, 1.0)``).
    * otherwise (uncertainty weighting, Kendall & Gal style), with ``s =
      uncertainty`` holding two log-variances:
      ``exp(-s[0]) * L_cls + s[0] + exp(-s[1]) * L_rot + s[1]``.
      ``weights`` is ignored in this mode.

    Gradients of the trunk and head parameters are clipped to a total norm of
    5.0. Gradients of ``uncertainty`` are not included in that clip. Then
    ``opt`` steps, and ``sched`` (if given) steps once per batch.
    """
    for module in (trunk, head_cls, head_rot):
        module.train()
    w_cls, w_rot = (1.0, 1.0) if weights is None else weights
    # The parameter objects stay the same for the whole epoch, so the list
    # handed to the gradient clipper can be built once up front.
    clip_params = [*trunk.parameters(), *head_cls.parameters(), *head_rot.parameters()]

    for x, y_cls, y_rot in dl:
        x = x.to(DEVICE)
        y_cls = y_cls.to(DEVICE)
        y_rot = y_rot.to(DEVICE)
        opt.zero_grad()
        feats = trunk(x)
        logits_cls, logits_rot = head_cls(feats), head_rot(feats)
        loss_cls = criterion(logits_cls, y_cls)
        loss_rot = criterion(logits_rot, y_rot)
        if uncertainty is not None:
            s = uncertainty
            loss = (torch.exp(-s[0]) * loss_cls + s[0]
                    + torch.exp(-s[1]) * loss_rot + s[1])
        else:
            loss = w_cls * loss_cls + w_rot * loss_rot
        loss.backward()
        nn.utils.clip_grad_norm_(clip_params, max_norm=5.0)
        opt.step()
        if sched is not None:
            sched.step()

# Equal loss weights. The accuracies from the last epoch (acc_cls, acc_rot)
# are the ones graded below.
for ep in range(1, EPOCHS_STAGE + 1):
    train_epoch_mtl(
        trunk, head_cls, head_rot, train_dl, opt, sched, criterion,
        weights=(1.0, 1.0),
    )
    acc_cls, acc_rot = evaluate_mtl(trunk, head_cls, head_rot, val_dl)
    print(
        f"[T1] Epoch {ep}/{EPOCHS_STAGE} | "
        f"val_acc_cls={acc_cls:.4f} val_acc_rot={acc_rot:.4f}"
    )

# Thresholds are checked from best to worst; the first (cls, rot) pair that
# both accuracies reach sets the score, otherwise the minimum of 4 applies.
t1_pts = next(
    (
        pts
        for min_cls, min_rot, pts in (
            (0.55, 0.55, 25),
            (0.45, 0.50, 18),
            (0.35, 0.45, 12),
            (0.30, 0.40, 8),
        )
        if acc_cls >= min_cls and acc_rot >= min_rot
    ),
    4,
)

raw_score += t1_pts
print(f"Task1 → +{t1_pts} pts (raw_score={raw_score})")

## 7) Task 2 — **Manual Reweighting** (max 25 pts)
Retrain a fresh model with manually chosen loss weights (`W_CLS`, `W_ROT`) to mitigate conflict between the two tasks.

In [None]:
# @title Run Task 2 (Manual Reweighting)
# A fresh model so Task 2 does not inherit Task 1's training.
trunk2 = Trunk().to(DEVICE)
head_cls2, head_rot2 = (
    HeadCls(trunk2.out_dim, num_classes).to(DEVICE),
    HeadRot(trunk2.out_dim, num_rot_bins).to(DEVICE),
)

criterion = nn.CrossEntropyLoss()
opt2 = optim.AdamW(
    [*trunk2.parameters(), *head_cls2.parameters(), *head_rot2.parameters()],
    lr=1e-3,
    weight_decay=1e-4,
)
sched2 = OneCycleLR(opt2, max_lr=1e-3, epochs=EPOCHS_STAGE,
                    steps_per_epoch=max(1, len(train_dl)))

W_CLS = 1.0  # @param {type:"number"}
W_ROT = 0.5  # @param {type:"number"}

# Manually weighted losses (W_CLS, W_ROT); the last epoch's accuracies are graded.
for ep in range(1, EPOCHS_STAGE + 1):
    train_epoch_mtl(
        trunk2, head_cls2, head_rot2, train_dl, opt2, sched2, criterion,
        weights=(W_CLS, W_ROT),
    )
    acc_cls2, acc_rot2 = evaluate_mtl(trunk2, head_cls2, head_rot2, val_dl)
    print(
        f"[T2] Epoch {ep}/{EPOCHS_STAGE} | "
        f"val_acc_cls={acc_cls2:.4f} val_acc_rot={acc_rot2:.4f}"
    )

# Full marks for a >=2-point gain on either task while neither task drops by
# more than 5 points relative to Task 1.
improve = (acc_cls2 >= acc_cls + 0.02) or (acc_rot2 >= acc_rot + 0.02)
not_collapse = (acc_cls2 >= acc_cls - 0.05) and (acc_rot2 >= acc_rot - 0.05)

if not not_collapse:
    t2_pts = 8
elif improve:
    t2_pts = 25
else:
    t2_pts = 16

raw_score += t2_pts
print(f"Task2 → +{t2_pts} pts (raw_score={raw_score})")

## 8) Task 3 — **Uncertainty Weighting** (max 30 pts)
Learn per-task log-variances (uncertainty weighting, Kendall, Gal & Cipolla, 2018) that automatically balance the two task losses.

In [None]:
# @title Run Task 3 (Uncertainty Weighting)
trunk3 = Trunk().to(DEVICE)
head_cls3 = HeadCls(trunk3.out_dim, num_classes).to(DEVICE)
head_rot3 = HeadRot(trunk3.out_dim, num_rot_bins).to(DEVICE)

# Learned per-task log-variances s_i = log(sigma_i^2) used to weight the losses.
log_vars = nn.Parameter(torch.zeros(2, device=DEVICE))  # [logσ²_cls, logσ²_rot]
criterion = nn.CrossEntropyLoss()
# log_vars live in their own parameter group with weight_decay=0. AdamW's
# decoupled decay would otherwise keep shrinking them toward 0, i.e. back to
# equal weights, working against the weighting this task is meant to learn.
# The model parameters keep the original weight decay of 1e-4.
opt3 = optim.AdamW(
    [
        {"params": list(trunk3.parameters()) + list(head_cls3.parameters()) + list(head_rot3.parameters())},
        {"params": [log_vars], "weight_decay": 0.0},
    ],
    lr=1e-3,
    weight_decay=1e-4,
)
sched3 = OneCycleLR(opt3, max_lr=1e-3, epochs=EPOCHS_STAGE, steps_per_epoch=max(1, len(train_dl)))

# Uncertainty-weighted losses; also log the current learned log-variances.
for ep in range(1, EPOCHS_STAGE + 1):
    train_epoch_mtl(
        trunk3, head_cls3, head_rot3, train_dl, opt3, sched3, criterion,
        uncertainty=log_vars,
    )
    acc_cls3, acc_rot3 = evaluate_mtl(trunk3, head_cls3, head_rot3, val_dl)
    print(
        f"[T3] Epoch {ep}/{EPOCHS_STAGE} | "
        f"val_acc_cls={acc_cls3:.4f} val_acc_rot={acc_rot3:.4f} | "
        f"log_vars={log_vars.detach().cpu().numpy()}"
    )

# 30 pts: both tasks reach at least Task 1's accuracy and at least 0.55.
# 20 pts: at least one task matches or beats Task 1. Otherwise 10 pts.
if acc_cls3 >= max(acc_cls, 0.55) and acc_rot3 >= max(acc_rot, 0.55):
    t3_pts = 30
else:
    t3_pts = 20 if (acc_cls3 >= acc_cls or acc_rot3 >= acc_rot) else 10

raw_score += t3_pts
print(f"Task3 → +{t3_pts} pts (raw_score={raw_score})")

## 9) Task 4 — **Soft Sharing (Sketch)** — Bonus up to +10

In [None]:
# @title Bonus (manual)
bonus_points = 0.0  # set 0..10 after implementing a soft-sharing variant and documenting
# Clamp the bonus to its documented 0..10 range, so a typo can neither
# subtract from nor over-inflate the raw score. The total is still capped at 100.
raw_score = min(100.0, raw_score + min(10.0, max(0.0, float(bonus_points))))
print("raw_score (with optional bonus, capped at 100) →", raw_score)

In [None]:
# @title 10) Summary of Results
# Final-epoch validation accuracies of each training stage, plus run settings.
summary = dict(
    task1_equal_weights=dict(acc_cls=float(acc_cls), acc_rot=float(acc_rot)),
    task2_manual=dict(acc_cls=float(acc_cls2), acc_rot=float(acc_rot2)),
    task3_uncertainty=dict(acc_cls=float(acc_cls3), acc_rot=float(acc_rot3)),
    epochs_per_stage=int(EPOCHS_STAGE),
    device=DEVICE,
)
print(json.dumps(summary, indent=2))

In [None]:
# @title 11) Finalize & Grade (Penalty + JSON)
import json

# Late-submission penalty as a fraction in [0, 1]. penalty_fraction and the
# start/due/submission timestamps come from the config cell. If those names
# are missing (that cell was not executed), no penalty is applied.
try:
    pf = penalty_fraction(start_dt, due_dt, submission_dt)
except NameError:
    pf = 0.0
# Defensive clamp: the penalty can never add points or exceed 100%.
pf = min(1.0, max(0.0, pf))

# ✅ Final score
max_points = 100
final_score = max(0.0, raw_score * (1.0 - pf))

print(f"Сырой балл: {raw_score}/{max_points}")
print(f"Штраф (доля): {pf:.4f}")
print(f"Итоговый балл после штрафа: {final_score:.2f}/{max_points}")

# Last line — JSON for the harness
final = {
    "name": full_name,
    "group": student_group,
    "assignment": assignment_id,
    "score": float(final_score)
}

print(json.dumps(final, ensure_ascii=False))