In [2]:
from pathlib import Path

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import transforms

In [3]:
# Backbone selection:
#   model_choice == 1 -> ResNet-50 loaded with torchvision's pretrained (ImageNet) weights
#   any other value   -> ResNet-18 with randomly initialised weights (training from scratch)
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of `weights=`;
# it is kept here for compatibility with older torchvision installs.
model_choice = 1
if model_choice == 1:
    model = torchvision.models.resnet50(pretrained=True)
else:
    model = torchvision.models.resnet18(pretrained=False)



In [4]:
# Load CIFAR-10 and resize every image to 224x224, normalising with the ImageNet
# per-channel mean/std used by the pretrained torchvision backbones.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: includes random horizontal flips as data augmentation.
myTransforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

# Evaluation pipeline: same preprocessing but no random augmentation, so the
# reported test accuracy is deterministic and not measured on flipped images.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

cifar10_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=myTransforms)
cifar10_data_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)

# Wrap the datasets in mini-batch loaders; only the training data needs shuffling
# (evaluation order does not affect accuracy).
train_loader = DataLoader(cifar10_data, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(cifar10_data_test, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Replace the backbone's final classification layer (model.fc) with a freshly
# initialised linear layer producing one logit per CIFAR-10 class.
# Only the output layer changes; the rest of the network is left as loaded.
NUM_CLASSES = 10
inchannel = model.fc.in_features  # input width of the original head, reused for the new one
model.fc = nn.Linear(inchannel, NUM_CLASSES)

In [6]:
# Report whether PyTorch can see a CUDA device (the training cell falls back to CPU otherwise).
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [9]:
# Train the model on the GPU when available (otherwise CPU), with periodic evaluation
# and early stopping. The best-scoring weights are restored at the end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.CrossEntropyLoss()  # default reduction='mean': one averaged loss per batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Epoch = 10                # maximum number of training epochs
verbose = True            # print the mean training loss after every epoch
validation = True         # evaluate on test_loader every `validation_per_epoch` epochs
validation_per_epoch = 2
best_accuracy = 0
best_epoch = 0
patience = 2              # stop when more than this many epochs have passed since the best score
best_state = None         # CPU copy of the weights that achieved `best_accuracy`

# NOTE(review): early stopping / model selection uses the *test* set, so the final test
# accuracy is optimistically biased; a held-out validation split would avoid this.
for epoch in range(Epoch):
    model.train()  # validation below switches to eval mode, so re-enable training mode each epoch
    # Accumulate on-device to avoid a host sync on every batch.
    running_loss = torch.zeros((), device=device)
    seen = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Weight the per-batch mean by batch size so the epoch mean is exact
        # even when the last batch is smaller.
        running_loss += loss.detach() * target.size(0)
        seen += target.size(0)
    if verbose:
        epoch_loss = running_loss.item() / max(seen, 1)
        print("Epoch {}/{}".format(epoch + 1, Epoch), "Mean loss: ", epoch_loss)

    if validation and (epoch + 1) % validation_per_epoch == 0:
        print("Validation")
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                # argmax instead of `_, predicted = torch.max(...)`: assigning `_`
                # would shadow IPython's last-output variable.
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        validation_accuracy = 100 * correct / total
        print("Accuracy of the network on the {} test images: {}%".format(total, validation_accuracy), '\n')

        if best_accuracy < validation_accuracy:
            best_accuracy = validation_accuracy
            best_epoch = epoch
            # Snapshot to CPU so keeping the best weights costs no extra GPU memory.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        elif epoch - best_epoch > patience:
            print("Early stopping")
            break

# Restore the best-scoring weights so the evaluation and save cells use that model
# rather than whatever the final epoch produced.
if best_state is not None:
    model.load_state_dict(best_state)
    if verbose:
        print("Restored weights from epoch {} ({}% accuracy)".format(best_epoch + 1, best_accuracy))

tensor([[ 1.0992e+00,  3.9075e+00, -8.9202e-01, -9.4518e-01, -7.3434e-01,
         -9.1207e-01, -1.7564e+00, -8.1872e-01,  7.0015e-01,  1.3082e+00],
        [ 1.2074e+00,  3.6372e+00, -7.1092e-01, -1.3833e+00, -1.6807e+00,
         -8.1355e-01, -2.2545e+00, -9.8465e-01,  2.5228e-01,  2.9522e+00],
        [ 1.0740e+00, -1.7928e+00,  7.1463e-01,  4.2458e-01,  1.2852e+00,
         -5.5438e-03, -7.1832e-01,  1.6732e+00, -1.9022e+00, -1.4753e+00],
        [ 1.9824e+00, -3.9121e-01,  5.0685e-01, -5.8304e-01, -1.3163e+00,
         -2.5366e-01, -1.9195e+00, -1.3829e+00,  3.6803e+00,  2.6224e-01],
        [-1.3330e+00, -1.6419e+00,  2.0493e+00,  1.0045e+00,  1.5042e+00,
         -7.3652e-01,  3.9831e+00, -1.2459e+00, -1.6538e+00, -2.0315e+00],
        [-4.1329e-01, -1.9064e+00,  1.3228e+00, -9.4978e-01,  3.6185e+00,
         -1.1371e+00,  3.6375e-01,  2.1664e+00, -1.5470e+00, -1.2630e+00],
        [-7.5323e-01, -1.9028e+00,  6.8478e-01,  1.2913e+00,  1.8111e-01,
          2.7138e+00,  9.5150e-0

KeyboardInterrupt: 

In [None]:
# Evaluate the model on every batch of test_loader and report top-1 accuracy.
model = model.to(device)  # no-op when the model is already on `device`
model.eval()              # inference mode: no dropout, batch-norm uses running statistics
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # argmax instead of `_, predicted = torch.max(output.data, 1)`: assigning `_`
        # would shadow IPython's last-output variable, and `.data` is a legacy idiom.
        predicted = model(data).argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
test_accuracy = 100 * correct / total
# Two decimals instead of %d, which truncated the accuracy to an integer.
print('Accuracy of the network on the {} test images: {:.2f} %'.format(total, test_accuracy))


Accuracy of the network on the 10000 test images: 96 %


In [None]:
# Save only the learned parameters (state_dict) to model/model.pth.
# torch.save does not create missing parent directories, so create model/ first.
model_dir = Path('model')
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_dir / 'model.pth')