In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from google.cloud import storage


class ConvNet(nn.Module):
    """Small convolutional classifier producing 10 output scores.

    Two (5x5 convolution -> ReLU -> 2x2 max-pool) stages feed three
    fully connected layers. The first linear layer takes 16 * 5 * 5
    features, which is what the unpadded convolutions and pooling yield
    for 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor; the single pooling layer is reused after
        # both convolutions.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Keep the batch dimension, collapse everything else.
        x = torch.flatten(x, start_dim=1)

        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))

        # No activation on the last layer: the caller gets logits.
        return self.fc3(x)

def get_training_loader(batch_size):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    The dataset is stored under ``./data`` (``download=True`` is passed to
    torchvision). Each image is converted to a tensor and normalised with
    mean 0.5 and std 0.5 per channel.

    Args:
        batch_size: Number of samples per batch yielded by the loader.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training set.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    cifar_train = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=preprocessing,
    )

    return torch.utils.data.DataLoader(
        cifar_train,
        batch_size=batch_size,
        shuffle=True,
    )

def train(dataloader, model, loss_fn, optimizer, epoch, device,
          log_every=8000):
    """Run one training epoch and periodically print the mean batch loss.

    Args:
        dataloader: Sized iterable yielding ``(inputs, targets)`` batches.
        model: The ``nn.Module`` to optimise; it is switched to train mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        optimizer: Optimizer holding ``model``'s parameters.
        epoch: Zero-based epoch index, used only in the progress message.
        device: Device the batches are moved to before the forward pass.
        log_every: Print a progress line each time the number of processed
            samples crosses a multiple of this value.

    Raises:
        ValueError: If ``log_every`` is not positive.
    """
    if log_every <= 0:
        raise ValueError("log_every must be a positive number of samples")

    num_batches = len(dataloader)

    model.train()  # enable training-mode behaviour of layers
    running_loss = 0.0
    batches_since_log = 0
    samples_seen = 0
    next_log_at = log_every

    for i, (X, y) in enumerate(dataloader):
        # Count real samples so a smaller final batch is not over-counted.
        samples_seen += len(X)

        X, y = X.to(device), y.to(device)

        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()

        # forward + backward + optimize
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_since_log += 1

        if samples_seen >= next_log_at:
            step_str = f'[{epoch + 1}, {i + 1:5d}/{num_batches}]'
            # Average over the batches actually accumulated since the last
            # report, not a fixed constant.
            mean_loss = running_loss / batches_since_log
            print(f'{step_str} loss: {mean_loss:.3f}')
            running_loss = 0.0
            batches_since_log = 0
            # Jump to the next multiple strictly above the current count so
            # batches larger than ``log_every`` still log once per batch.
            next_log_at = (samples_seen // log_every + 1) * log_every

def save_model(model, bucket_name, model_name):
    """Upload ``model`` to a Cloud Storage bucket as ``<model_name>.pt``.

    The whole model object (not only its ``state_dict``) is handed to
    ``torch.save`` and streamed directly into the blob.

    Args:
        model: Object to serialise with ``torch.save``.
        bucket_name: Name of the destination bucket.
        model_name: Blob name without the ``.pt`` extension.
    """
    destination = (
        storage.Client()
        .bucket(bucket_name)
        .blob(f"{model_name}.pt")
    )
    # NOTE(review): ignore_flush=True is presumably needed because
    # torch.save flushes the stream while writing; confirm against the
    # google-cloud-storage writer documentation.
    with destination.open("wb", ignore_flush=True) as stream:
        torch.save(model, stream)


def run(args):
    """Train a ``ConvNet`` on the training loader and upload the result.

    Args:
        args: Object exposing ``batch_size``, ``learning_rate``,
            ``num_epochs``, ``bucket_name`` and ``model_name`` attributes.
    """
    # Data
    loader = get_training_loader(args.batch_size)

    # Prefer the GPU when one is visible to torch.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using {device} device")

    # Model, objective and optimizer
    net = ConvNet().to(device)
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=args.learning_rate)

    # Optimisation loop: one call to train() per epoch.
    for epoch in range(args.num_epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        train(loader, net, criterion, sgd, epoch, device)

    # Persist the trained model to the configured bucket.
    save_model(net, args.bucket_name, args.model_name)