In [2]:
import torch
from torchvision import models, transforms
import torchvision.datasets as dsets
from torch.utils.data import DataLoader
from torch import nn, manual_seed, optim

manual_seed(0)

<torch._C.Generator at 0x299ebd214f0>

In [3]:
model = models.resnet18(pretrained=True)

In [4]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    transforms.Normalize(mean, std)
])
train_dataset = dsets.MNIST(root='./data', train=True, download=True, transform=transform)
validation_dataset = dsets.MNIST(root='./data', train=False, download=True, transform=transform)

In [5]:
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(512, 10)

In [6]:
criterion = nn.CrossEntropyLoss()
learning_rate = 0.03
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
train_loader = DataLoader(dataset=train_dataset, batch_size=128)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=5000)

In [7]:
def train_model(n_epochs):
    cost_list = []
    acc_list = []
    n_test = len(validation_dataset)
    for epoch in range(n_epochs):
        print(f"Starting train epoch {epoch + 1}")
        COST = 0
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            COST += loss.data
        print(f"Done.")
        print(f"\tTraining Cost: {COST}")

        cost_list.append(COST)
        correct = 0
        # measure accuracy on the validation data
        model.eval()
        for x_test, y_test in validation_loader:
            out = model(x_test)
            _, yhat = torch.max(out.data, 1)
            correct += (yhat == y_test).sum().item()
        accuracy = correct / n_test
        acc_list.append(accuracy)
        print(f"\tValidation Accuracy: {accuracy}")


train_model(10)

Starting train epoch 1
Done.
	Training Cost: 843.1636962890625
	Validation Accuracy: 0.694
Starting train epoch 2
Done.
	Training Cost: 816.0101928710938
	Validation Accuracy: 0.6998
Starting train epoch 3
Done.
	Training Cost: 864.1507568359375
	Validation Accuracy: 0.7043
Starting train epoch 4
Done.
	Training Cost: 898.8692626953125
	Validation Accuracy: 0.7058
Starting train epoch 5
Done.
	Training Cost: 929.3705444335938
	Validation Accuracy: 0.7177
Starting train epoch 6
Done.
	Training Cost: 936.6265869140625
	Validation Accuracy: 0.707
Starting train epoch 7
Done.
	Training Cost: 940.835693359375
	Validation Accuracy: 0.7084
Starting train epoch 8
Done.
	Training Cost: 945.0850830078125
	Validation Accuracy: 0.7086
Starting train epoch 9
Done.
	Training Cost: 957.314453125
	Validation Accuracy: 0.7154
Starting train epoch 10


KeyboardInterrupt: 