In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import jdot
import numpy as np

In [16]:
transform = transforms.Compose([
	transforms.ToTensor(),
	transforms.Normalize((0.1307,), (0.3081,))
])

In [17]:
train_dataset_source = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset_source = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

target_dataset = torchvision.datasets.USPS(root='./data', download=True, transform=transform)

In [18]:
X_train, y_train = next(iter(DataLoader(train_dataset_source, batch_size=len(train_dataset_source) // 10)))
X_test, y_test = next(iter(DataLoader(test_dataset_source, batch_size=len(test_dataset_source) // 10)))
X_target, y_target = next(iter(DataLoader(target_dataset, batch_size=len(target_dataset))))
X_target = torch.nn.functional.pad(X_target, (6, 6, 6, 6), mode='constant', value=-0.4242)
X_train = X_train.view(-1, 28*28)
X_test = X_test.view(-1, 28*28)
X_target = X_target.view(-1, 28*28)
y_train = torch.nn.functional.one_hot(y_train, num_classes=10).float()
y_test = torch.nn.functional.one_hot(y_test, num_classes=10).float()
y_target = torch.nn.functional.one_hot(y_target, num_classes=10).float()

In [19]:
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
print(X_target.shape, y_target.shape)

torch.Size([6000, 784]) torch.Size([6000, 10])
torch.Size([1000, 784]) torch.Size([1000, 10])
torch.Size([7291, 784]) torch.Size([7291, 10])


In [20]:
class Model(torch.nn.Module):
	def __init__(self, n_epochs=100, device='mps'):
		super(Model, self).__init__()
		self.model = torch.nn.Sequential(
			torch.nn.Linear(28 * 28, 128),
			torch.nn.ReLU(),
			torch.nn.Dropout(0.5),
			torch.nn.Linear(128, 10),
			torch.nn.Softmax(dim=1)
		).to(device)
		self.n_epochs = n_epochs
		self.device = device
		self.criterion = torch.nn.MSELoss()
		self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

	def fit(self, X, y):
		# fit the model
		# y = torch.nn.functional.one_hot(y, num_classes=10).float().to(self.device)
		for epoch in range(self.n_epochs):
			self.model.train()
			X, y = X.to(self.device), y.to(self.device)
			self.optimizer.zero_grad()
			output = self.model(X)
			loss = self.criterion(output, y)
			loss.backward()
			self.optimizer.step()
			if epoch % 10 == 0:
				print(f'Epoch {epoch}/{self.n_epochs}, Loss: {loss.item()}')
		print(f'Epoch {epoch + 1}/{self.n_epochs}, Loss: {loss.item()}')

	@torch.no_grad()
	def predict(self, X):
		# predict the labels
		self.model.eval()
		X = X.to(self.device)
		predicted = self.model(X)
		# print(output.shape)
		# _, predicted = torch.max(output, dim=1)
		predicted = predicted.cpu().detach()
		return predicted


def make_model(n_epochs=100, device='mps'):
	return Model(n_epochs=n_epochs, device=device)

In [21]:
model = make_model()

In [22]:
print(X_train.shape, y_train.shape)
model.fit(X_train, y_train)

torch.Size([6000, 784]) torch.Size([6000, 10])
Epoch 0/100, Loss: 0.09113242477178574
Epoch 10/100, Loss: 0.04473089426755905
Epoch 20/100, Loss: 0.02574266865849495
Epoch 30/100, Loss: 0.01927793025970459
Epoch 40/100, Loss: 0.016251111403107643
Epoch 50/100, Loss: 0.01392282359302044
Epoch 60/100, Loss: 0.01271914504468441
Epoch 70/100, Loss: 0.010977528057992458
Epoch 80/100, Loss: 0.010337648913264275
Epoch 90/100, Loss: 0.009373887442052364
Epoch 100/100, Loss: 0.009149998426437378


In [23]:
y_pred = model.predict(X_test)
print(f'Accuracy: {(torch.argmax(y_pred, axis=1) == torch.argmax(y_test, axis=1)).float().mean()}')

Accuracy: 0.9100000262260437


In [24]:
y_pred = model.predict(X_target)
print(f'Accuracy: {(torch.argmax(y_pred, axis=1) == torch.argmax(y_target, axis=1)).float().mean()}')

Accuracy: 0.6447675228118896


In [28]:
model = make_model()
model.fit(X_train, y_train)
y_pred = model.predict(X_target)

Epoch 0/100, Loss: 0.0910920724272728
Epoch 10/100, Loss: 0.04670117422938347
Epoch 20/100, Loss: 0.02704324945807457
Epoch 30/100, Loss: 0.020045023411512375
Epoch 40/100, Loss: 0.01620771363377571
Epoch 50/100, Loss: 0.014194965362548828
Epoch 60/100, Loss: 0.012634809128940105
Epoch 70/100, Loss: 0.011202187277376652
Epoch 80/100, Loss: 0.010528859682381153
Epoch 90/100, Loss: 0.009507548063993454
Epoch 100/100, Loss: 0.009115871042013168


In [31]:
print(np.mean(y_pred.numpy() - y_target.numpy())**2)

4.0141113429840795e-23


In [25]:
model, results = jdot.jdot_nn_l2(make_model, X_train, y_train, X_target, y_target, n_epochs=100)

Epoch 0/100, Loss: 0.09204807877540588
Epoch 10/100, Loss: 0.043571509420871735
Epoch 20/100, Loss: 0.025820491835474968
Epoch 30/100, Loss: 0.019205406308174133
Epoch 40/100, Loss: 0.016102982684969902
Epoch 50/100, Loss: 0.013771014288067818
Epoch 60/100, Loss: 0.012746588326990604
Epoch 70/100, Loss: 0.011232252232730389
Epoch 80/100, Loss: 0.010442824102938175
Epoch 90/100, Loss: 0.009645842015743256
Epoch 100/100, Loss: 0.008937009610235691


  result_code_string = check_result(result_code)


torch.Size([7291, 10])
Epoch 0/100, Loss: 0.08941204100847244
Epoch 10/100, Loss: 0.0647502988576889
Epoch 20/100, Loss: 0.048449233174324036
Epoch 30/100, Loss: 0.038064803928136826
Epoch 40/100, Loss: 0.032277822494506836
Epoch 50/100, Loss: 0.02861899323761463
Epoch 60/100, Loss: 0.025833413004875183
Epoch 70/100, Loss: 0.024242978543043137
Epoch 80/100, Loss: 0.022693434730172157
Epoch 90/100, Loss: 0.021499918773770332
Epoch 100/100, Loss: 0.021160224452614784
torch.Size([7291, 10])
Epoch 0/100, Loss: 0.08983055502176285
Epoch 10/100, Loss: 0.0603010393679142
Epoch 20/100, Loss: 0.042515452951192856
Epoch 30/100, Loss: 0.032213352620601654
Epoch 40/100, Loss: 0.026288466528058052
Epoch 50/100, Loss: 0.022990217432379723
Epoch 60/100, Loss: 0.020464060828089714
Epoch 70/100, Loss: 0.01819414086639881
Epoch 80/100, Loss: 0.017262177541851997
Epoch 90/100, Loss: 0.01607838273048401
Epoch 100/100, Loss: 0.015088362619280815
torch.Size([7291, 10])
Epoch 0/100, Loss: 0.09082012623548508

In [14]:
y_pred = model.predict(X_target)
print(f'Accuracy: {(torch.argmax(y_pred, axis=1) == torch.argmax(y_target, axis=1)).float().mean()}')

Accuracy: 0.5054176449775696
