In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torchvision import transforms
from torch.utils.data import DataLoader

from src.utils import create_label_mapping
from model.resnet import resnet18, ResNet18_Weights
from src.model_training import train_model
from data.ImageNetV2.superclassing_dataset import SuperclassImageNetV2Dataset, MappedSuperclassImageNetV2Dataset


SUPERCLASS_DATASET_PATH="./data/ImageNetV2/raw/"
IMAGENET_CLASS_INDEX_PATH = "./data/ImageNetV2/imagenet_class_index.json"
SUPER_CLASS_INDEX_PATH = "./data/ImageNetV2/superclass/superclass_index.json"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

superclass_distribution={-1: 9, 0: 64, 7: 42, 4: 130, 3: 9, 5: 22, 8: 21, 2: 13, 1: 16, 6: 12}

# train superclass network

In [2]:
train_ratio = 0.8
val_ratio = 0.2

super_dataset = SuperclassImageNetV2Dataset(transform=transform, root=SUPERCLASS_DATASET_PATH, superclass_index_path=SUPER_CLASS_INDEX_PATH)

train_size = int(train_ratio * len(super_dataset))
val_size = len(super_dataset) - train_size

train_dataset, val_dataset = random_split(super_dataset, [train_size, val_size])

train_loader = DataLoader(super_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [3]:
num_classes=superclass_distribution.get(-1)
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
superclass_network = train_model(model, -1, train_loader, val_loader, criterion, optimizer, device=device)
torch.save(superclass_network.state_dict(), f"./results/superclass_network.pth")

Epoch 0/0
----------
train Loss: 0.3772 Acc: 0.8891
val Loss: 0.2303 Acc: 0.9362


# train all subnetwork

In [6]:
for superclass_id, num_classes in superclass_distribution.items():
    if superclass_id==-1:
        continue
    print(f"Training sub-network for superclass ID: {superclass_id} with {num_classes} classes")

    train_ratio = 0.8
    val_ratio = 0.2

    label_mapping = create_label_mapping(super_dataset, superclass_id)

    dataset = MappedSuperclassImageNetV2Dataset(
        superclass=superclass_id,
        transform=transform,
        root=SUPERCLASS_DATASET_PATH,
        superclass_index_path=SUPER_CLASS_INDEX_PATH,
        label_mapping=label_mapping
    )

    train_size = int(train_ratio * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    sub_model = resnet18(weights=ResNet18_Weights.DEFAULT)
    sub_model.fc = nn.Linear(sub_model.fc.in_features, num_classes)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(sub_model.parameters(), lr=0.001, momentum=0.9)

    subclass_network = train_model(sub_model, superclass_id, train_loader, val_loader, criterion, optimizer, device=device)
    torch.save(subclass_network.state_dict(), f"./results/subclass_network_{superclass_id}.pth")


Training sub-network for superclass ID: 0 with 64 classes
Epoch 0/0
----------
train Loss: 4.2861 Acc: 0.0195
val Loss: 4.3301 Acc: 0.0078
Training sub-network for superclass ID: 7 with 42 classes
Epoch 0/0
----------


KeyboardInterrupt: 