In [None]:
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

In [5]:
def train(net, train_iter, test_iter, num_epochs, lr, optimizer=None, device=d2l.try_gpu()):
    """
    Trains the network 'net' with cross-entropy loss and plots progress.

    Before training, the weights of every nn.Linear and nn.Conv2d layer in
    'net' are re-initialised with Xavier-uniform, so previously learned
    weights in those layers are overwritten.

    Args:
        net: the nn.Module to train. It is moved to 'device' in place.
        train_iter: iterable of (X, y) batches used for the optimisation steps.
            Must support len(), which is used to place points on the x-axis.
        test_iter: iterable of (X, y) batches handed to evaluate_accuracy_gpu
            at the end of every epoch.
        num_epochs: number of passes over 'train_iter'.
        lr: learning rate of the default SGD optimizer. Not used when
            'optimizer' is supplied.
        optimizer: optional optimizer to use instead of the default SGD.
        device: device to train on. Note that the default is computed once,
            when this function is defined, not on every call.

    Returns:
        None. The metrics of the last training batch and the last test
        accuracy are printed; they show as 'nan' if no batch/epoch was run.
    """
    # 1: initialise weights (exact type match, subclasses are left untouched)
    def init_weights_test(m):
        if type(m) in (nn.Linear, nn.Conv2d):
            torch.nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights_test)

    # 2: move model to device for training
    net.to(device)

    # 3: set up optimizer, loss function, and animation stuff
    loss = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.SGD(params=net.parameters(), lr=lr)
    animator = d2l.Animator(xlabel="epoch number", xlim=[0, num_epochs], legend=["train loss", "train acc", "test acc"])

    # Pre-bind the reported metrics so the final print cannot hit a NameError
    # when num_epochs is 0 or train_iter yields no batches.
    train_loss = train_acc = test_acc = float("nan")

    # 4: training loop
    for epoch in range(num_epochs):
        # evaluate_accuracy_gpu switches the net to eval mode at the end of
        # each epoch, so training mode has to be restored here.
        net.train()
        for i, (X, y) in enumerate(train_iter):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            # temporarily disable grad to calculate metrics
            with torch.no_grad():
                train_loss = l
                _, preds = torch.max(y_hat, 1)
                train_acc = ((preds == y).sum()) / float(X.shape[0])
            if (i + 1) % 50 == 0:
                # Hand plain Python floats to the plotting helper rather than
                # live tensors (the loss tensor still carries its autograd graph).
                animator.add(epoch + (i / len(train_iter)), (float(train_loss), float(train_acc), None))
        test_acc = evaluate_accuracy_gpu(net, test_iter, device)
        animator.add(epoch + 1, (None, None, test_acc))

    # Convert once at the end; these are the metrics of the very last batch.
    train_loss, train_acc = float(train_loss), float(train_acc)
    print(f'loss {train_loss:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')

In [7]:
def evaluate_loss(net, data_iter, device, loss=nn.CrossEntropyLoss()):
    """
    Evaluate the average per-batch loss of 'net' on 'data_iter'.

    Args:
        net: the nn.Module to evaluate. It is switched to eval mode and left
            in that mode.
        data_iter: iterable of (X, y) batches.
        device: device the batches are moved to before the forward pass.
        loss: callable taking (predictions, targets) and returning a scalar
            loss for the batch. Defaults to cross-entropy.

    Returns:
        The mean of the per-batch loss values as a Python float, or NaN when
        'data_iter' yields no batches.
    """
    net.eval()
    total_loss, num_batches = 0.0, 0
    # Evaluation only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            total_loss += float(loss(net(X), y))
            num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

In [8]:
def evaluate_accuracy_gpu(net, data_iter, device):
    """
    Evaluate the accuracy of the model given by 'net' on
    the DataLoader given by 'data_iter' using the device 'device'.

    Args:
        net: the nn.Module to evaluate. It is switched to eval mode and left
            in that mode, so callers must call net.train() before training again.
        data_iter: iterable of (X, y) batches; y holds the target class indices.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of samples whose highest-scoring class (along dim 1 of the
        network output) equals the target, as a Python float.

    Raises:
        ZeroDivisionError: if 'data_iter' yields no samples.
    """
    net.eval()
    num_correct, num_total = 0, 0
    # Evaluation only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            _, predicted = torch.max(net(X), 1)
            num_correct += (predicted == y).sum()
            num_total += y.shape[0]
    return float(num_correct) / num_total

In [9]:
def get_fashion_mnist_iters(batch_size=128, resize=224, data_dir="./data", train_fraction=0.8, seed=None):
    """
    Build DataLoaders for Fashion-MNIST, downloading the data if necessary.

    The official training set is split randomly into a training part and a
    validation part; the official test set is returned untouched.

    Args:
        batch_size: batch size used by all four loaders.
        resize: every image is resized to (resize, resize) before being
            converted to a tensor.
        data_dir: directory the dataset is downloaded to / read from.
        train_fraction: fraction of the official training set kept for
            training; the remainder becomes the validation set.
        seed: optional integer seed for the train/validation split. When
            None the split uses PyTorch's global random state.

    Returns:
        A tuple (all_iter, train_iter, val_iter, test_iter). 'all_iter' covers
        the full official training set. The two training loaders are shuffled;
        the validation and test loaders are not, so evaluation order is stable.

    Raises:
        ValueError: if 'train_fraction' is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((resize, resize)),
        torchvision.transforms.ToTensor()
    ])

    all_set = torchvision.datasets.FashionMNIST(data_dir, transform=transform, download=True)
    test_set = torchvision.datasets.FashionMNIST(data_dir, transform=transform, download=True, train=False)

    # Build a validation set from the official training data
    num_train = int(train_fraction * len(all_set))
    split_lengths = [num_train, len(all_set) - num_train]

    # Only pass a generator when a seed was requested, so the default call
    # keeps using the global random state.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_set, val_set = torch.utils.data.random_split(all_set, split_lengths, **split_kwargs)

    all_iter = torch.utils.data.DataLoader(all_set, batch_size=batch_size, shuffle=True)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Evaluation loaders do not need shuffling.
    val_iter = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return (all_iter, train_iter, val_iter, test_iter)

In [1]:
def pretty_size(size):
    """Format a torch.Size as its dimensions joined by " × " (empty string for a 0-dim size)."""
    assert isinstance(size, torch.Size)
    dims = [str(dim) for dim in size]
    return " × ".join(dims)

def _pinned_label(tensor):
    """Return " pinned" when 'tensor' reports pinned memory, otherwise ""."""
    try:
        # is_pinned is a method: it has to be called, the bare attribute is always truthy.
        return " pinned" if tensor.is_pinned() else ""
    except Exception:
        # Best effort: if the query itself fails, list the tensor as not
        # pinned instead of dropping it from the report.
        return ""


def dump_tensors(gpu_only=True):
    """Prints a list of the Tensors being tracked by the garbage collector.

    Every object returned by gc.get_objects() is inspected. Tensors are
    listed directly; non-tensor objects whose 'data' attribute is a tensor
    are listed as "wrapper → tensor". The summed element count of all listed
    tensors is printed last.

    Args:
        gpu_only: when True (default) only tensors that live on a CUDA device
            are listed and counted.

    Returns:
        None. All output goes to stdout.
    """
    import gc
    total_size = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if not gpu_only or obj.is_cuda:
                    print("%s:%s%s %s" % (type(obj).__name__,
                                          " GPU" if obj.is_cuda else "",
                                          _pinned_label(obj),
                                          pretty_size(obj.size())))
                    total_size += obj.numel()
                continue

            wrapped = getattr(obj, "data", None)
            if torch.is_tensor(wrapped) and (not gpu_only or wrapped.is_cuda):
                # Read the device/pinned flags from the wrapped tensor and fall
                # back to False for flags the wrapper does not expose, so the
                # object is still listed and counted.
                print("%s → %s:%s%s%s%s %s" % (type(obj).__name__,
                                               type(wrapped).__name__,
                                               " GPU" if wrapped.is_cuda else "",
                                               _pinned_label(wrapped),
                                               " grad" if getattr(obj, "requires_grad", False) else "",
                                               " volatile" if getattr(obj, "volatile", False) else "",
                                               pretty_size(wrapped.size())))
                total_size += wrapped.numel()
        except Exception:
            # Inspecting arbitrary live objects can raise from their attribute
            # hooks; this is a diagnostic helper, so skip such objects.
            continue
    print("Total size:", total_size)