In [1]:
import itertools
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm

In [2]:
# Seed torch's global RNG before any model or DataLoader is created, so weight
# initialisation and loader shuffling draw from a fixed seed across re-runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using {device}")

batch_size = 128

using cuda


In [3]:
def load_imagenet_data(data_dir, batch_size=32, num_workers=4, pin_memory=False):
    """Build ImageNet train/validation DataLoaders with standard preprocessing.

    Training images get a random 224x224 resized crop and a random horizontal
    flip. Validation images are resized to 256 and center-cropped to 224. Both
    splits are converted to tensors and normalized with the ImageNet channel
    mean/std.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.ImageNet``.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker processes per loader (default 4, the value that was
            previously hard-coded).
        pin_memory: Forwarded to ``DataLoader``. Setting it to True can speed up
            host-to-GPU copies.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    # ImageNet models expect input size of 224x224.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageNet(root=data_dir, split='train', transform=train_transform)
    val_dataset = datasets.ImageNet(root=data_dir, split='val', transform=val_transform)

    # Both loaders share everything except shuffling.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

# Root of the ImageNet (ILSVRC2012) data. Set the IMAGENET_DIR environment
# variable to point elsewhere instead of editing the notebook.
data_dir = os.environ.get('IMAGENET_DIR', '/datasets/ILSVRC2012')
train_loader, val_loader = load_imagenet_data(data_dir, batch_size=batch_size)

In [4]:
# Rough data-loading throughput check: time how long the validation loader
# takes to yield its first `num_batches_to_read` batches. The measurement also
# includes worker start-up, because iterating the loader starts its workers.
num_batches_to_read = 100

start_time = time.perf_counter()
# islice yields exactly `num_batches_to_read` batches. An enumerate/break loop
# would fetch one extra batch before it reached the break.
batches_read = sum(1 for _ in itertools.islice(val_loader, num_batches_to_read))
elapsed = time.perf_counter() - start_time
print(f"read {batches_read} batches in {elapsed:.2f}s ({batches_read / elapsed:.1f} batches/s)")

In [5]:
def initialize_model(model_name='resnet50', num_classes=1000):
    """Build a classification model by name.

    Args:
        model_name: Architecture to build. Only ``'resnet50'`` is currently
            supported.
        num_classes: Output width of the final fully connected layer.

    Returns:
        nn.Module: ``torchvision.models.resnet50()`` built with no weights
        argument (torchvision's default, i.e. no pretrained weights requested),
        with its ``fc`` head replaced.

    Raises:
        ValueError: If ``model_name`` is not supported. ValueError subclasses
            ``Exception``, so existing ``except Exception`` handlers still work.
    """
    # TODO: ViT support (e.g. timm.create_model('vit_base_patch16_224',
    # num_classes=num_classes)) was sketched here but never enabled. Add a
    # branch once timm is an actual dependency.
    if model_name != 'resnet50':
        raise ValueError("Model not supported: {}".format(model_name))

    model = models.resnet50()
    # Swap the classifier head so the output width matches num_classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# initialize_model only supports 'resnet50'; any other name raises.
model = initialize_model('resnet50', num_classes=1000)

In [6]:
def train_model(model, train_loader, criterion, device, optimizer, num_epochs=10):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Module to optimise. It is moved to ``device`` and put in train mode.
        train_loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        device: Device that the model and every batch are moved to.
        optimizer: Optimizer already constructed over ``model``'s parameters.
        num_epochs: Number of full passes over ``train_loader``.

    Returns:
        list[float]: One entry per epoch, the unweighted mean of that epoch's
        per-batch losses. The entry is NaN for an epoch that saw no batches.
    """
    model.to(device)
    model.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        # Using the bar as a context manager closes it even if a batch raises.
        with tqdm(train_loader, desc=f'Training Epoch {epoch + 1}', total=len(train_loader)) as progress:
            for images, labels in progress:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                num_batches += 1
                # Show the noisy per-batch loss alongside a running epoch average.
                progress.set_postfix(loss=batch_loss, avg_loss=running_loss / num_batches)
        epoch_losses.append(running_loss / num_batches if num_batches else float('nan'))
    return epoch_losses

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and report loss and top-1 accuracy.

    Args:
        model: Module to evaluate. It is moved to ``device`` and put in eval mode.
        val_loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
            NOTE(review): assumes it averages over the batch
            (``reduction='mean'``, the ``nn.CrossEntropyLoss`` default), because
            each batch loss is re-weighted by its sample count below.
        device: Device that the model and every batch are moved to.

    Returns:
        tuple[float, float]: ``(mean_loss_per_sample, accuracy_percent)``, or
        ``(nan, nan)`` if the loader yields no samples.
    """
    model.to(device)
    model.eval()

    total = 0
    correct = 0
    # Sum of per-sample losses, so a short final batch is not over-weighted.
    loss_sum = 0.0
    with torch.no_grad(), tqdm(val_loader, desc='Validation', total=len(val_loader)) as progress:
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            num_samples = labels.size(0)
            loss_sum += loss.item() * num_samples
            total += num_samples
            correct += (outputs.argmax(dim=1) == labels).sum().item()

            # Running averages over the samples seen so far (not the whole loader).
            progress.set_postfix(val_loss=loss_sum / total, accuracy=100 * correct / total)

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100 * correct / total


In [7]:
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10

criterion = nn.CrossEntropyLoss()
# Move the model to its device before building the optimizer, as the PyTorch
# optim docs advise, so the optimizer is constructed over the final parameters.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_model(model, train_loader, criterion, device, optimizer, num_epochs=NUM_EPOCHS)
validate_model(model, val_loader, criterion, device)


Training Epoch 1:  20%|█▉        | 1986/10010 [09:12<37:09,  3.60it/s, loss=5.76]