In [1]:
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=7),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.ToTensor()
])

eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


In [2]:
import os
import json

image_root = r"D:\CompCars\image"

year_map = {}

for brand in os.listdir(image_root):
    brand_path = os.path.join(image_root, brand)
    if not os.path.isdir(brand_path):
        continue

    for model in os.listdir(brand_path):
        model_path = os.path.join(brand_path, model)
        if not os.path.isdir(model_path):
            continue

        for year in os.listdir(model_path):
            year_path = os.path.join(model_path, year)
            if not os.path.isdir(year_path):
                continue

            if not year.isdigit():
                continue

            year_int = int(year)

            # Only keep realistic car years
            if year_int < 1980 or year_int > 2025:
                continue

            for fname in os.listdir(year_path):
                if fname.lower().endswith((".jpg", ".png", ".jpeg")):
                    year_map[fname] = year_int

print("Total labeled images:", len(year_map))
print("Unique years:", sorted(set(year_map.values())))

with open("year_labels.json", "w") as f:
    json.dump(year_map, f, indent=2)


Total labeled images: 136186
Unique years: [1993, 2000, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016]


In [3]:
print(year_map["043602fa21a621.jpg"])

2013


In [4]:
with open("year_labels.json") as f:
    year_map = json.load(f)

years = sorted(set(year_map.values()))
print(years)
print(len(years))


[1993, 2000, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016]
17


In [5]:
from torchvision.datasets import ImageFolder

In [6]:
class CarDatasetWithYear(ImageFolder):
    def __init__(self, root, transform, year_map, year_to_idx):
        super().__init__(root, transform=transform)
        self.year_map = year_map
        self.year_to_idx = year_to_idx

    def __getitem__(self, index):
        path, model_label = self.samples[index]
        fname = os.path.basename(path)

        if fname not in self.year_map:
            return self.__getitem__((index + 1) % len(self))

        year = self.year_map[fname]
        year_label = self.year_to_idx[year]

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)

        return image, model_label, year_label


In [7]:
import json

with open("year_labels.json") as f:
    year_map = json.load(f)

years = sorted(set(year_map.values()))
year_to_idx = {year: i for i, year in enumerate(years)}
idx_to_year = {i: year for year, i in year_to_idx.items()}

print("Years:", years)
print("Number of year classes:", len(years))


Years: [1993, 2000, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016]
Number of year classes: 17


In [8]:
train_dataset = CarDatasetWithYear("../dataset/train", train_transform, year_map, year_to_idx)
val_dataset   = CarDatasetWithYear("../dataset/val",   eval_transform,  year_map, year_to_idx)
test_dataset  = CarDatasetWithYear("../dataset/test",  eval_transform,  year_map, year_to_idx)

In [9]:
import torch
import torch.nn as nn
from torchvision.models import resnet50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_MODELS = 429   # your trained model classes
NUM_YEARS  = len(years)

# Recreate backbone
backbone = resnet50(weights=None)
backbone.fc = nn.Identity()

# Recreate your model head (exact same as before)
class ModelHead(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm1d(2048),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes)
        )
    def forward(self, x):
        return self.net(x)

model_head = ModelHead(NUM_MODELS)

# Load trained weights
state = torch.load("saved_models/resnet50_compcars_20260111_154612.pth", map_location=device)

# We need to load backbone + model_head weights
backbone.load_state_dict({k.replace("backbone.", ""): v for k, v in state.items() if k.startswith("backbone.")})
model_head.load_state_dict({k.replace("classifier.", "net."): v for k, v in state.items() if k.startswith("classifier.")})

backbone = backbone.to(device)
model_head = model_head.to(device)


In [10]:
class MultiTaskCarNet(nn.Module):
    def __init__(self, backbone, model_head, num_years):
        super().__init__()
        self.backbone = backbone
        self.model_head = model_head
        self.year_head = nn.Linear(2048, num_years)

    def forward(self, x):
        feats = self.backbone(x)
        model_logits = self.model_head(feats)
        year_logits  = self.year_head(feats)
        return model_logits, year_logits

multi_net = MultiTaskCarNet(backbone, model_head, NUM_YEARS).to(device)


In [11]:
for p in multi_net.backbone.parameters():
    p.requires_grad = False

for p in multi_net.model_head.parameters():
    p.requires_grad = False

for p in multi_net.year_head.parameters():
    p.requires_grad = True


In [12]:
sum(p.requires_grad for p in multi_net.parameters())


2

In [13]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)


In [14]:
year_criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(
    multi_net.year_head.parameters(),
    lr=1e-3,
    weight_decay=1e-4
)


In [15]:
missing = 0
for path, _ in train_dataset.samples:
    fname = os.path.basename(path)
    if fname not in year_map:
        missing += 1

print("Train images without year:", missing, "/", len(train_dataset))


Train images without year: 8 / 10533


In [16]:
missing_val = 0
for path, _ in val_dataset.samples:
    fname = os.path.basename(path)
    if fname not in year_map:
        missing_val += 1

print("Val images without year:", missing_val, "/", len(val_dataset))


Val images without year: 3 / 2416


In [17]:
print("Year classes:", len(years))
print("Min year:", min(years))
print("Max year:", max(years))

Year classes: 17
Min year: 1993
Max year: 2016


In [18]:
from tqdm import tqdm

def train_year_epoch(model, loader):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, _, years in tqdm(loader, desc="Train-Year"):
        images = images.to(device)
        years  = years.to(device)

        optimizer.zero_grad()

        _, year_logits = model(images)
        loss = year_criterion(year_logits, years)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        preds = year_logits.argmax(1)
        correct += (preds == years).sum().item()
        total += years.size(0)

    return total_loss / total, correct / total


def val_year_epoch(model, loader):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, _, years in tqdm(loader, desc="Val-Year"):
            images = images.to(device)
            years  = years.to(device)

            _, year_logits = model(images)
            loss = year_criterion(year_logits, years)

            total_loss += loss.item() * images.size(0)
            preds = year_logits.argmax(1)
            correct += (preds == years).sum().item()
            total += years.size(0)

    return total_loss / total, correct / total


In [19]:
class EarlyStopping:
    def __init__(self, patience=3):
        self.best = None
        self.counter = 0
        self.best_state = None
        self.patience = patience

    def step(self, val_acc, model):
        if self.best is None or val_acc > self.best:
            self.best = val_acc
            self.counter = 0
            self.best_state = {
                k: v.detach().cpu().clone()
                for k, v in model.state_dict().items()
            }
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience


In [20]:
from time import time
t0 = time()
images, _, years = next(iter(train_loader))
print("First batch loaded in:", time() - t0, "seconds")


First batch loaded in: 0.8400783538818359 seconds


In [22]:
early = EarlyStopping(patience=3)

EPOCHS = 15

for epoch in range(EPOCHS):
    train_loss, train_acc = train_year_epoch(multi_net, train_loader)
    val_loss, val_acc = val_year_epoch(multi_net, val_loader)

    print(f"Epoch {epoch+1}: Train Acc={train_acc:.3f} | Val Acc={val_acc:.3f}")

    if early.step(val_acc, multi_net):
        print("Early stopping triggered")
        multi_net.load_state_dict(early.best_state)
        break


Train-Year: 100%|██████████| 330/330 [05:12<00:00,  1.06it/s]
Val-Year: 100%|██████████| 76/76 [00:52<00:00,  1.44it/s]


Epoch 1: Train Acc=0.319 | Val Acc=0.380


Train-Year: 100%|██████████| 330/330 [03:07<00:00,  1.76it/s]
Val-Year: 100%|██████████| 76/76 [00:25<00:00,  2.96it/s]


Epoch 2: Train Acc=0.408 | Val Acc=0.381


Train-Year: 100%|██████████| 330/330 [02:43<00:00,  2.01it/s]
Val-Year: 100%|██████████| 76/76 [00:27<00:00,  2.80it/s]


Epoch 3: Train Acc=0.437 | Val Acc=0.421


Train-Year: 100%|██████████| 330/330 [02:51<00:00,  1.92it/s]
Val-Year: 100%|██████████| 76/76 [00:25<00:00,  2.98it/s]


Epoch 4: Train Acc=0.456 | Val Acc=0.397


Train-Year: 100%|██████████| 330/330 [02:41<00:00,  2.04it/s]
Val-Year: 100%|██████████| 76/76 [00:26<00:00,  2.83it/s]


Epoch 5: Train Acc=0.476 | Val Acc=0.409


Train-Year: 100%|██████████| 330/330 [02:52<00:00,  1.91it/s]
Val-Year: 100%|██████████| 76/76 [00:26<00:00,  2.83it/s]


Epoch 6: Train Acc=0.486 | Val Acc=0.444


Train-Year: 100%|██████████| 330/330 [02:53<00:00,  1.90it/s]
Val-Year: 100%|██████████| 76/76 [00:26<00:00,  2.82it/s]


Epoch 7: Train Acc=0.495 | Val Acc=0.429


Train-Year: 100%|██████████| 330/330 [02:52<00:00,  1.92it/s]
Val-Year: 100%|██████████| 76/76 [00:26<00:00,  2.84it/s]


Epoch 8: Train Acc=0.506 | Val Acc=0.431


Train-Year: 100%|██████████| 330/330 [02:52<00:00,  1.91it/s]
Val-Year: 100%|██████████| 76/76 [00:27<00:00,  2.81it/s]

Epoch 9: Train Acc=0.507 | Val Acc=0.437
Early stopping triggered





In [None]:
import torch

torch.save({
    "backbone": multi_net.backbone.state_dict(),
    "model_head": multi_net.model_head.state_dict(),
    "year_head": multi_net.year_head.state_dict(),
    "year_to_idx": year_to_idx,
    "idx_to_year": idx_to_year
}, "saved_models/multitask_car_net.pth")

print("Multi-task model saved.")

Multi-task model saved.
