In [None]:
import os
import xml.etree.ElementTree as ET

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

# ---- Paths ---- #
# Root of the ILSVRC download. Set the ILSVRC_ROOT environment variable to run
# on another machine; the default is the original hard-coded location, so the
# derived paths below are unchanged when the variable is not set.
ILSVRC_ROOT = os.environ.get("ILSVRC_ROOT", "/home/kajm20/mnist/ILSVRC")
IMAGE_DIR = os.path.join(ILSVRC_ROOT, "Data", "CLS-LOC", "train")
ANNOTATION_DIR = os.path.join(ILSVRC_ROOT, "Annotations", "CLS-LOC", "train")
MAPPING_PATH = os.path.join(ILSVRC_ROOT, "LOC_synset_mapping.txt")

# ---- Label Mapping ---- #
# Each non-blank line of the mapping file starts with a WordNet id, optionally
# followed by a description. The 0-based position of the id among those lines
# becomes the integer class label used for training.
# Blank lines (e.g. an extra newline at the end of the file) are skipped; with
# the previous `split(' ', 1)` they failed to unpack. split() with no separator
# also accepts tab-separated and id-only lines.
with open(MAPPING_PATH, encoding="utf-8") as f:
    _wordnet_ids = [line.split(maxsplit=1)[0] for line in f if line.strip()]
wordnet_to_imagenet = {wnid: idx for idx, wnid in enumerate(_wordnet_ids)}

# ---- Transform ---- #
# Per-channel (R, G, B) mean and std conventionally used to normalize ImageNet
# inputs.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Pipeline steps:
#   1. Resize the shorter side to 256.
#   2. Take the central 224x224 crop.
#   3. Convert to a float tensor.
#   4. Normalize each channel.
# NOTE(review): this is deterministic (no random crop/flip), i.e. the usual
# evaluation-style pipeline, yet it is used for training here. Confirm that the
# lack of augmentation is intentional.
imagenet_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ---- Custom Dataset ---- #
class ImageNetTrainDataset(Dataset):
    """ImageNet CLS-LOC training samples, one per annotation XML file.

    Samples are the ``*.xml`` files found one directory level below
    ``annotation_dir`` (``annotation_dir/<class_dir>/<name>.xml``), in sorted
    order. Each XML is parsed lazily in ``__getitem__``. Parsing yields the
    image file stem (``<filename>``) and the WordNet id of the first
    ``<object>``. The image is then read from
    ``image_dir/<wordnet_id>/<filename>.JPEG``.

    NOTE(review): only images that have an annotation XML are included. The
    logged run (8509 batches of 64, about 544k images) suggests this is a
    subset of the full training split. Confirm that is intended.

    Args:
        image_dir: Root of the training images, one sub-directory per WordNet id.
        annotation_dir: Root of the annotation XMLs, one sub-directory per class.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        label_map: Optional ``{wordnet_id: class_index}`` dict. When ``None``,
            the module-level ``wordnet_to_imagenet`` mapping is used, looked up
            each time an item is fetched.
    """

    def __init__(self, image_dir, annotation_dir, transform=None, label_map=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.label_map = label_map
        self.samples = self._gather_samples()

    def _gather_samples(self):
        """Return the paths of all ``.xml`` files one level below ``annotation_dir``.

        Class directories and file names are both sorted, because
        ``os.listdir`` order is arbitrary. This keeps the sample order
        reproducible across runs and filesystems.
        """
        samples = []
        for class_dir in sorted(os.listdir(self.annotation_dir)):
            class_annotation_dir = os.path.join(self.annotation_dir, class_dir)
            if not os.path.isdir(class_annotation_dir):
                continue  # ignore stray files at the top level
            samples.extend(
                os.path.join(class_annotation_dir, filename)
                for filename in sorted(os.listdir(class_annotation_dir))
                if filename.endswith(".xml")
            )
        return samples

    def _parse_annotation(self, xml_path):
        """Return ``(image_path, class_idx)`` described by one annotation XML.

        Raises:
            ValueError: if the XML lacks ``<object>/<name>`` or ``<filename>``.
            KeyError: if the WordNet id is not present in the label map.
        """
        root = ET.parse(xml_path).getroot()
        wordnet_id = root.findtext("object/name")
        image_stem = root.findtext("filename")
        if not wordnet_id or not image_stem:
            raise ValueError(
                f"Malformed annotation (missing object/name or filename): {xml_path}"
            )

        label_map = wordnet_to_imagenet if self.label_map is None else self.label_map
        try:
            class_idx = label_map[wordnet_id]
        except KeyError:
            # Fail loudly. A placeholder such as -1 is not a valid class index
            # for CrossEntropyLoss, whose default ignore_index is -100, and it
            # only surfaces later as an opaque error inside the loss.
            raise KeyError(
                f"WordNet id {wordnet_id!r} from {xml_path} is not in the label map"
            ) from None

        image_path = os.path.join(self.image_dir, wordnet_id, image_stem + ".JPEG")
        return image_path, class_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_idx)`` for sample ``idx``.

        ``image`` is the RGB image after ``transform`` (if one was given).
        """
        image_path, class_idx = self._parse_annotation(self.samples[idx])

        # Open inside a context manager so the file handle is closed right away,
        # not whenever the lazily-loaded PIL image is garbage-collected.
        # convert() forces the pixel data to load before the file closes.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        return image, class_idx

# ---- Dataset & Loader ---- #
# Training samples come from the annotation XMLs; every image goes through
# imagenet_transform.
train_dataset = ImageNetTrainDataset(
    IMAGE_DIR,
    ANNOTATION_DIR,
    transform=imagenet_transform,
)

# Batches of 64, reshuffled every epoch and produced by 4 worker processes.
# pin_memory=True returns batches in page-locked host memory for faster
# host-to-GPU copies.
# NOTE(review): the logged ~4.6 s/it suggests image loading may be the
# bottleneck. Consider profiling and raising num_workers.
_loader_kwargs = dict(batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, **_loader_kwargs)

# ---- Model ---- #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# EfficientNet-B0 without pretrained weights (weights=None). The classifier head
# is sized from the label mapping, so the number of logits always matches the
# label indices the dataset produces, rather than relying on the default of 1000.
model = models.efficientnet_b0(weights=None, num_classes=len(wordnet_to_imagenet))
model.to(device)

# ---- Loss and Optimizer ---- #
# Use the fully qualified torch.nn / torch.optim names: this cell imports only
# `torch` itself, not the `nn` / `optim` aliases.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# ---- Training Loop ---- #
def train(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` in place for ``num_epochs`` passes over ``dataloader``.

    After each epoch, prints the mean per-sample training loss and the top-1
    accuracy (in percent) over that epoch's batches.

    Args:
        model: Module to optimize.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch *mean* (the ``CrossEntropyLoss`` default); confirm
            this if a different reduction is passed in.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        list of ``(mean_loss, accuracy_percent)`` tuples, one per epoch. Both
        values are ``nan`` for an epoch with no samples.
    """
    history = []
    for epoch in range(num_epochs):
        # Set training mode per epoch so it stays correct even if code added
        # between epochs (e.g. validation) switches the model to eval().
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            # non_blocking makes the copy asynchronous when the source tensor
            # is in pinned memory (DataLoader pin_memory=True). Otherwise it
            # has no effect.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # criterion returns a batch mean. Weighting it by batch size makes
            # the epoch figure a true per-sample mean (even when the last batch
            # is smaller), instead of a sum over thousands of batches.
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

        # Guard against an empty loader instead of raising ZeroDivisionError.
        avg_loss = loss_sum / total if total else float("nan")
        acc = 100 * correct / total if total else float("nan")
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")
        history.append((avg_loss, acc))

    return history

# ---- Run Training ---- #
# Launch 10 epochs over the training loader; per-epoch loss and accuracy are
# printed by train().
train(
    model=model,
    dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=10,
)


Epoch 1/10: 100%|██████████| 8509/8509 [10:53:18<00:00,  4.61s/it] 


Epoch 1 Loss: 44328.8021, Accuracy: 7.56%


Epoch 2/10:   2%|▏         | 156/8509 [12:18<10:06:20,  4.36s/it]