In [1]:
# Imports
import torch
from tqdm import tqdm
from torchvision import datasets, transforms

In [2]:
# Reproducibility / cuDNN configuration.
# Seed PyTorch's global RNG so that weight initialisation and DataLoader
# shuffling repeat on every "Restart Kernel -> Run All".
torch.manual_seed(0)

# Let cuDNN benchmark the available convolution algorithms and pick the fastest.
# NOTE(review): per the PyTorch reproducibility notes, benchmark mode may select
# different algorithms between runs; set it to False if bit-exact repeatability
# matters more than speed.
torch.backends.cudnn.benchmark = True
# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# Fetch the ResNet-101 architecture through torch.hub. pretrained=False means
# the weights are randomly initialised rather than loaded from a checkpoint.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet101', pretrained=False)

# Swap the classification head for a 10-way linear layer, keeping the input
# width that the existing head already expects (2048 for ResNet-101).
model.fc = torch.nn.Linear(model.fc.in_features, 10)

# Move the whole network (parameters and buffers) onto the GPU.
model = model.to('cuda')

In [None]:
# CIFAR-10 preprocessing: convert each image to a tensor, then normalise all
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split first, test split second. Both are downloaded into ./data
# when missing and share the same preprocessing pipeline.
trainset, testset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True,
                     transform=transform)
    for is_train in (True, False)
)


In [None]:
# Mini-batch iterators over the two CIFAR-10 splits (64 samples per batch,
# one worker process each). Only the training data is shuffled; the test
# order stays fixed.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=64,
    num_workers=1,
    shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=64,
    num_workers=1,
    shuffle=False,
)


In [None]:
# Cross-entropy loss computed from the model's raw class scores.
criterion = torch.nn.CrossEntropyLoss()
# Adam over every parameter of the current `model`, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Baseline: test-set accuracy of the network before any training.

model.eval()  # evaluation mode (affects layers such as BatchNorm)
correct, total = 0, 0
with torch.no_grad(), tqdm(testloader, unit="batch") as progress:
    progress.set_description("Test")
    for images, labels in progress:
        images = images.cuda()
        labels = labels.cuda()
        logits = model(images)
        # Index of the highest score along the class dimension.
        predicted = logits.max(1)[1]
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        # Running accuracy (in percent) over the batches seen so far.
        progress.set_postfix(Accuracy=(100 * correct / total))


In [None]:
# Train for 10 epochs under autocast (no GradScaler is used in this loop),
# measuring test accuracy after every epoch.
for epoch in range(10):
    # Re-enter training mode at the start of every epoch. The model is put in
    # eval mode by the test phase below (and by the pre-training accuracy
    # check), and without this call BatchNorm layers would keep using frozen
    # running statistics while the weights are being optimised.
    model.train()
    with tqdm(trainloader, unit="batch") as tepoch:
        for data, target in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # Forward pass and loss run in mixed precision.
            with torch.cuda.amp.autocast():
                output = model(data)
                loss = criterion(output, target)
            # Backward pass on the unscaled loss.
            loss.backward()
            optimizer.step()
            tepoch.set_postfix(loss=loss.item())

    # Test
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        with tqdm(testloader, unit="batch") as t2epoch:
            for data, target in t2epoch:
                t2epoch.set_description(f"Epoch {epoch}")
                data, target = data.cuda(), target.cuda()
                output = model(data)
                # Predicted class = index of the largest score per sample.
                _, predicted = torch.max(output.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
                # Running accuracy (in percent) over the batches seen so far.
                t2epoch.set_postfix(Accuracy=(100 * correct / total))

    print(" ")


In [None]:
# Start over with a fresh, randomly initialised ResNet-101 with a 10-class head
# on the GPU.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet101', pretrained=False)
model.fc = torch.nn.Linear(2048, 10)
model = model.cuda()

# Rebuild the optimizer for the new model. The previous `optimizer` was
# constructed from the old model's parameters, so stepping it would leave this
# fresh network untouched and the following training run would learn nothing.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Gradient scaler for mixed-precision training: the training loop scales the
# loss through it before backward() and lets it drive the optimizer step.
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [None]:
# Mixed-precision training with gradient scaling (epochs numbered 1..10),
# followed by a test-set pass after every epoch.
for epoch in range(1, 11):
    model.train()
    with tqdm(trainloader, unit="batch") as train_bar:
        train_bar.set_description(f"Epoch {epoch}")
        for images, labels in train_bar:
            images = images.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            # Forward pass and loss in mixed precision.
            with torch.cuda.amp.autocast():
                output = model(images)
                loss = criterion(output, labels)
            # Scale the loss for backward, then let the scaler step the
            # optimizer and adjust its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_bar.set_postfix(loss=loss.item())

    # Test
    model.eval()
    correct, total = 0, 0
    with torch.no_grad(), tqdm(testloader, unit="batch") as test_bar:
        test_bar.set_description(f"Epoch {epoch}")
        for images, labels in test_bar:
            with torch.cuda.amp.autocast():
                images, labels = images.cuda(), labels.cuda()
                output = model(images)
                # Predicted class = index of the largest score per sample.
                predicted = output.max(1)[1]
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            # Running accuracy (in percent) over the batches seen so far.
            test_bar.set_postfix(Accuracy=(100 * correct / total))

    print(" ")
