In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import numpy as np
from torchvision import datasets, transforms


# Reproducible shuffling (DataLoader uses torch's global RNG) and weight init.
torch.manual_seed(0)

# load patches data from files
train_images_patches = np.load('data/mnist25_train_patches.npy')
test_images_patches = np.load('data/mnist25_test_patches.npy')

# Get the labels from datasets.MNIST. `.targets` already holds every label as
# an int64 tensor, so there is no need to iterate the dataset (which would run
# ToTensor on all 70k images only to discard the pixels).
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform)
train_labels = mnist_train.targets.long()
test_labels = mnist_test.targets.long()

# Make them PyTorch tensors, and dataloaders.
# .float(): the model layers are float32; the dtype stored in the .npy files is
# not visible here, so cast explicitly to avoid a dtype mismatch at training
# time. TensorDataset itself checks that patches and labels have equal length.
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_images_patches).float(), train_labels)
test_dataset = torch.utils.data.TensorDataset(torch.from_numpy(test_images_patches).float(), test_labels)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Sanity-check one batch of the train loader.
images, labels = next(iter(train_loader))
print(images.shape)
print(labels.shape)


torch.Size([32, 25, 5, 5])
torch.Size([32])


In [None]:
# def a network.
class MeanNet(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    In this notebook it is fed the per-pixel mean over the 25 patches of each
    image, i.e. input of shape ``(batch, 25)``, and returns ``(batch, 10)``
    unnormalized logits.

    Args:
        in_features: length of the input feature vector (default 25).
        hidden: width of the hidden layer (default 20).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_features=25, hidden=20, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return raw class logits (pair with ``nn.CrossEntropyLoss``)."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    

# Train MeanNet on the per-pixel mean of each image's 25 patches and evaluate on
# the test set after every epoch.
num_epochs = 100
learning_rate = 0.01
checkpoint_path = None  # e.g. 'data/patch_mnist.pth' to persist the best weights

model = MeanNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

max_acc, min_loss = 0.0, float('inf')
for epoch in range(num_epochs):
    # train the network; re-enter train mode every epoch because the
    # evaluation below switches the model to eval mode
    model.train()
    train_loss_sum, train_count = 0.0, 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        # (B, 25, 5, 5) -> (B, 25, 25) -> mean over the patch axis -> (B, 25)
        images = images.view(-1, 25, 25).mean(dim=1)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size for an epoch mean
        train_loss_sum += loss.item() * labels.size(0)
        train_count += labels.size(0)
    print('epoch {}, loss {}'.format(epoch, train_loss_sum / train_count))

    # test model
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        test_loss_sum = 0.0
        for images, labels in test_loader:
            images = images.view(-1, 25, 25).mean(dim=1)
            outputs = model(images)
            # accumulate over ALL test batches (not just the last one); the
            # last batch may be smaller, hence the weighting by batch size
            test_loss_sum += criterion(outputs, labels).item() * labels.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        acc = correct / total
        test_loss = test_loss_sum / total

    # keep the best model: only when BOTH accuracy and loss improve
    if acc > max_acc and test_loss < min_loss:
        max_acc, min_loss = acc, test_loss
        if checkpoint_path is not None:
            torch.save(model.state_dict(), checkpoint_path)
        print('new best at epoch {}, acc {}, loss {}'.format(epoch, max_acc, min_loss))

In [None]:
# def a network.
class ConvNet(nn.Module):
    """Small CNN classifier over a single-channel 25x25 map.

    conv1 and conv2 (5x5, padding 2, stride 1) keep the 25x25 spatial size;
    conv3 (5x5, stride 5, no padding) reduces it to 5x5 = 25 values, which a
    two-layer MLP maps to 10 logits. Expects input ``(batch, 1, 25, 25)`` and
    returns ``(batch, 10)`` unnormalized logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(3, 1, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(1, 1, kernel_size=5, stride=5, padding=0)
        self.fc1 = nn.Linear(25, 20)
        self.fc2 = nn.Linear(20, 10)

    def forward(self, x):
        """Return raw class logits (pair with ``nn.CrossEntropyLoss``)."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # Flatten all but the batch axis: (B, 1, 5, 5) -> (B, 25). Unlike
        # view(-1, 25), a wrong input size then fails loudly in fc1 instead of
        # silently changing the batch size.
        x = x.flatten(1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    

# Train ConvNet on each image's patch tensor laid out as a single-channel 25x25
# map, evaluating on the test set after every epoch.
num_epochs = 100
learning_rate = 0.01
checkpoint_path = None  # e.g. 'data/patch_mnist_conv.pth' (distinct from the MeanNet checkpoint)

model = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

max_acc, min_loss = 0.0, float('inf')
for epoch in range(num_epochs):
    # train the network; re-enter train mode every epoch because the
    # evaluation below switches the model to eval mode
    model.train()
    train_loss_sum, train_count = 0.0, 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        # (B, 25, 5, 5) -> (B, 1, 25, 25): Conv2d needs a channel axis and
        # conv1 takes 1 input channel.
        # NOTE(review): this puts each flattened 5x5 patch on one row of the
        # 25x25 map; it matches the original image layout only if the patches
        # are re-tiled (permuted) first -- TODO confirm the intended arrangement.
        images = images.view(-1, 1, 25, 25)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size for an epoch mean
        train_loss_sum += loss.item() * labels.size(0)
        train_count += labels.size(0)
    print('epoch {}, loss {}'.format(epoch, train_loss_sum / train_count))

    # test model
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        test_loss_sum = 0.0
        for images, labels in test_loader:
            images = images.view(-1, 1, 25, 25)
            outputs = model(images)
            # accumulate over ALL test batches (not just the last one); the
            # last batch may be smaller, hence the weighting by batch size
            test_loss_sum += criterion(outputs, labels).item() * labels.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        acc = correct / total
        test_loss = test_loss_sum / total

    # keep the best model: only when BOTH accuracy and loss improve
    if acc > max_acc and test_loss < min_loss:
        max_acc, min_loss = acc, test_loss
        if checkpoint_path is not None:
            torch.save(model.state_dict(), checkpoint_path)
        print('new best at epoch {}, acc {}, loss {}'.format(epoch, max_acc, min_loss))

In [9]:
# test shape: a 5x5 / stride-5 conv (as in ConvNet.conv3) maps a 25x25 map to
# 5x5 = 25 features, the in_features of ConvNet.fc1.
# Probe with a batched (N, C, H, W) input, the layout nn.Conv2d uses for
# mini-batches, and flatten all but the batch axis.
probe_conv = nn.Conv2d(1, 1, kernel_size=5, stride=5)
probe_in = torch.randn(1, 1, 25, 25)
probe_out = probe_conv(probe_in).flatten(1)
assert probe_out.shape == (1, 25), probe_out.shape
print(probe_out.shape)

torch.Size([1, 25])
