In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchinfo import summary
from tqdm import tqdm
import matplotlib.pyplot as plt

import time
import os
import copy

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = None

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)
print("Pytorch device: ", device)


PyTorch Version:  1.13.1
Torchvision Version:  0.14.1a0
Pytorch device:  mps


In [4]:
from torchvision.models import resnet

resnet18 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', weights=resnet.ResNet18_Weights.IMAGENET1K_V1)
resnet34 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', weights=resnet.ResNet34_Weights.IMAGENET1K_V1)
resnet50 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', weights=resnet.ResNet50_Weights.IMAGENET1K_V1)
resnet101 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', weights=resnet.ResNet101_Weights.IMAGENET1K_V1)
resnet152 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', weights=resnet.ResNet152_Weights.IMAGENET1K_V1)

Using cache found in /Users/marakim/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/marakim/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/marakim/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/marakim/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/marakim/.cache/torch/hub/pytorch_vision_v0.10.0


In [5]:
resnets = [
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
]

In [6]:
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
training_data = torchvision.datasets.ImageFolder('~/Documents/research/imagenet/ILSVRC/Data/CLS-LOC/train/', transform=preprocess)

In [13]:
from torchdata.datapipes.iter import IterableWrapper

def train_model(model, training_data, num_epochs=10):
    training_loader = torch.utils.data.DataLoader2(
        training_data,
        batch_size=256,
    )
    #testing_loader = torch.utils.data.DataLoader2(IterableWrapper(testing_data), batch_size=256)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.AdamW(model.parameters())

    model.train()

    model.to(device)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for inputs, labels in tqdm(training_loader):
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            loss = criterion(outputs.to(device), labels.to(device))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f'[epoch {epoch}]: loss: {epoch_loss}')

In [14]:
train_model(resnet18, training_data)

  4%|▍         | 211/5005 [05:50<2:12:37,  1.66s/it]


KeyboardInterrupt: 