In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch.optim as optim
import math
import numpy as np
import tednet.tednet.tnn.tensor_ring as tednet_tr
import tednet.tednet.tnn.tensor_train as tednet_tt

# Model setup

## Tensor Layer definitions

In [2]:
class TRLinearLayer(nn.Module):
    def __init__(self, in_shape, out_shape, ranks, bias: bool = True):
        super().__init__()
        self.n_info = {}
        self.layer = tednet_tr.TRLinear(in_shape, out_shape, ranks, bias=bias)
        self.n_info["ori_params"] = self.layer.tn_info["ori_params"]
        self.n_info["t_params"] = self.layer.tn_info["t_params"]
    def forward(self, x):
        return self.layer(x)

class TTLinearLayer(nn.Module):
    def __init__(self, in_shape, out_shape, ranks, bias: bool = True):
        super().__init__()
        self.n_info = {}
        self.layer = tednet_tt.TTLinear(in_shape, out_shape, ranks, bias=bias)
        self.n_info["ori_params"] = self.layer.tn_info["ori_params"]
        self.n_info["t_params"] = self.layer.tn_info["t_params"]
    def forward(self, x):
        return self.layer(x)



## Model definitions
We test different tensor decompositions methods by training a fully connected network model over the MNIST dataset.
The test model has layers [784, 320, 100, 10]

In [3]:
# Define the LeNet-300-100 model
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(784, 320),
            nn.ReLU(),
            nn.Linear(320, 100),
            nn.ReLU(),
            nn.Linear(100, 10), # Returns logits
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.classifier(x)
        return x

class LeNetTR(nn.Module):
    def __init__(self):
        super(LeNetTR, self).__init__()
        self.n_info = {}
        self.flatten = nn.Flatten()
        self.l1 = TRLinearLayer([7,7,4,4], [5,8,8], [7,7,4,4,5,8,8], bias=True)
        self.l2 = TRLinearLayer([5,8,8], [10, 10], [5,5,5,5,5], bias=True)
        self.l3 = TRLinearLayer([10,10], [10], [5,5,5], bias=True)
        self.classifier = nn.Sequential(
            self.l1,
            nn.ReLU(),
            self.l2,
            nn.ReLU(),
            self.l3, # Returns logits
        )
        self.n_info["ori_params"] = sum([x.n_info["ori_params"] for x in [self.l1, self.l2, self.l3]])
        self.n_info["t_params"] = sum([x.n_info["t_params"] for x in [self.l1, self.l2, self.l3]])
        self.n_info["cr"] = self.n_info["ori_params"] / self.n_info["t_params"]

        print("LeNet TR ---")
        print("Original params: " + str(self.n_info["ori_params"]))
        print("TN params: " + str(self.n_info["t_params"]))
        print("Compression ratio: " + str(self.n_info["cr"]))

    def forward(self, x):
        x = self.flatten(x)
        x = self.classifier(x)
        return x
        
class LeNetTT(nn.Module):
    def __init__(self):
        super(LeNetTT, self).__init__()
        self.n_info = {}
        self.flatten = nn.Flatten()
        self.l1 = TTLinearLayer([7,7,4,4], [5,4,4,4], [4,4,4], bias=True)
        self.l2 = TTLinearLayer([5,4,4,4], [5,5,2,2], [4,4,4], bias=True)
        self.l3 = TTLinearLayer([5,5,2,2], [1,1,1,10], [4,4,4], bias=True)
        self.classifier = nn.Sequential(
            self.l1,
            nn.ReLU(),
            self.l2,
            nn.ReLU(),
            self.l3, # Returns logits
        )
        self.n_info["ori_params"] = sum([x.n_info["ori_params"] for x in [self.l1, self.l2, self.l3]])
        self.n_info["t_params"] = sum([x.n_info["t_params"] for x in [self.l1, self.l2, self.l3]])
        self.n_info["cr"] = self.n_info["ori_params"] / self.n_info["t_params"]

        print("LeNet TT ---")
        print("Original params: " + str(self.n_info["ori_params"]))
        print("TN params: " + str(self.n_info["t_params"]))
        print("Compression ratio: " + str(self.n_info["cr"]))

    def forward(self, x):
        x = self.flatten(x)
        x = self.classifier(x)
        return x

# Training functions

## Training helper functions

In [4]:
def evaluate_validation_set(model, validation_loader, device):
    model.eval()  # Set the model to evaluation mode
    total = 0
    correct = 0
    with torch.no_grad():  # Disable gradient computation
        for x_val, y_val in validation_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            outputs = model(x_val)
            _, predicted = torch.max(outputs.data, 1)
            total += y_val.size(0)
            correct += (predicted == y_val).sum().item()

    accuracy = correct / total
    return accuracy


def show_example_batch(train_data):
    # Get the first batch
    train_loader, valid_loader = train_data
    
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    
    # Plot the images in the batch, along with the corresponding labels
    fig = plt.figure(figsize=(25, 4))
    plot_size=20
    for idx in np.arange(plot_size):
        ax = fig.add_subplot(2, int(plot_size/2), idx+1, xticks=[], yticks=[])
        ax.imshow(np.squeeze(images[idx].numpy()), cmap='gray')
        # print out the correct label for each image
        # .item() gets the value contained in a Tensor
        ax.set_title(str(labels[idx].item()))
    plt.imshow(images[0].numpy().squeeze(), cmap='gray_r')
def get_train_data():
    transform = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,)),
                                  ])
    # Download and load the training data
    trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
    validation_size = 0.10
    num_train = len(trainset)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    # Calculate the number of data points in the validation set
    split = int(np.floor(validation_size * num_train))
    print(split)
    
    # Train_idx => Imatges per entrenar
    # Valid_idx => Imatges per verificar i comprovar el model
    train_idx, valid_idx = indices[split:], indices[:split]
    
    # Create data samplers
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)
    
    # Create dataloaders
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(trainset, batch_size=64, sampler=valid_sampler)

    return train_loader, valid_loader

## Train model function

In [20]:
def train_model(model, device, train_data, learning_rate=1.2e-3, batch_size=60, epochs=10):
    losses = []
    accuracy_list = []
    val_accuracy_list = []
    grad_norms = []
    print("Training model")
    print(model)

    train_loader, valid_loader = train_data
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for param in model.parameters():
        print(param)
    # Validation Calculation
    val_check_iter = 3
    iterations_per_epoch = len(train_loader)
    
    for epoch in range(epochs):
        for i, (x_batch, y_batch) in enumerate(train_loader):
            # Move data to the device
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
    
            optimizer.zero_grad()
            output = model(x_batch)
            loss = nn.CrossEntropyLoss()(output, y_batch)
            loss.backward()
    
    
            # Calculate and store gradient norm
            total_grad_norm = torch.sqrt(sum(p.grad.norm()**2 for p in model.parameters() if p.grad is not None))
            grad_norms.append(total_grad_norm.item())
    
            optimizer.step()
    
            # Store the loss
            losses.append(loss.item())
    
            # Store the accuracy 
            _, argmax = torch.max(output, 1)
            accuracy = (y_batch == argmax.squeeze()).float().mean()
            accuracy_list.append(accuracy)
    
            if i % int(len(train_loader) / 50) == 0:
                print(".", end='')
    
            if i % int(len(train_loader) / val_check_iter) == 0:
                # Calculate validation accuracy at the end of each epoch
                val_accuracy = evaluate_validation_set(model, valid_loader, device)
                val_accuracy_list.append(val_accuracy)
                
                # Calculate average loss and accuracy over an epoch
                avg_loss = torch.mean(torch.tensor(losses[-len(train_loader):]))
                avg_accuracy = torch.mean(torch.tensor(accuracy_list[-len(train_loader):]))
        
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss.item():.4f}, Accuracy: {avg_accuracy.item():.4f}, Val Accuracy: {val_accuracy:.4f}')
    infos = {
        "losses": losses,
        "accuracy_list": accuracy_list,
        "grad_norms": grad_norms,
        "val_accuracy_list": val_accuracy_list,
        "iterations_per_epoch": iterations_per_epoch,
        "model_name": model.__class__.__name__
    }
    return infos

## Show training results function

In [21]:
def show_training_infos(training_infos, epochs):
    
    average_window = 200
    infos = []
    for training_info in training_infos:
        losses = torch.tensor(training_info["losses"]).cpu().numpy()
        accuracy_list = torch.tensor(training_info["accuracy_list"]).cpu().numpy()
        grad_norms = torch.tensor(training_info["grad_norms"]).cpu().numpy()
        val_accuracy_list = training_info["val_accuracy_list"]
        iterations_per_epoch = training_info["iterations_per_epoch"]

        average_losses = [np.mean(losses[i-average_window:i]) for i in range(0, len(losses), average_window)]
        avg_accuracy_list = [np.mean(accuracy_list[i-average_window:i]) for i in range(0, len(accuracy_list), average_window)]
        avg_grad_norms = [np.mean(grad_norms[i-average_window:i]) for i in range(0, len(grad_norms), average_window)]

        iterations = np.arange(len(average_losses)) * average_window
        x_iterations = np.arange(0, iterations_per_epoch * epochs,(iterations_per_epoch * epochs) / len(val_accuracy_list))
        infos.append({
            "name": training_info["model_name"],
            "iterations": iterations,
            "average_losses": average_losses,
            "avg_grad_norms": avg_grad_norms,
            "avg_accuracy_list": avg_accuracy_list,
            "val_accuracy_list": val_accuracy_list,
            "x_iterations": x_iterations
        })
    
    # Plotting the loss curve (average, with max and min as error bands)
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    for info in infos:
        plt.plot(info["iterations"], info["average_losses"], label=info["name"])
    
    plt.ylim(0, .5)
    
    plt.title('Training Loss Curve Average')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()
    
    # Plotting the gradient norm curve
    plt.subplot(1, 2, 2)
    for info in infos:
        plt.plot(info["iterations"], info["avg_grad_norms"], label=info["name"])
    plt.title('Gradient Norm Curve Average')
    plt.xlabel('Iterations')
    plt.ylabel('Gradient Norm')
    plt.legend()
    
    # Plotting the accuracy curve (average, with max and min as error bars)
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    for info in infos:
        plt.plot(info["iterations"], info["avg_accuracy_list"], label=info["name"])
    plt.title('Training Accuracy Curve Average')
    plt.xlabel('Iterations')
    plt.ylabel('Accuracy')
    plt.legend()
    
    # Zoom to 0.9 to 1.0 range
    plt.ylim(0.9, 1.0)
    
    # Plotting the validation accuracy curve
    x_iterations = np.arange(0, iterations_per_epoch * epochs,(iterations_per_epoch * epochs) / len(val_accuracy_list))
    plt.subplot(1, 2, 2)
    for info in infos:
        plt.plot(info["x_iterations"], info["val_accuracy_list"], label=info["name"])
    plt.title('Validation Accuracy Curve Average')
    plt.xlabel('Iterations')
    plt.ylabel('Accuracy')
    plt.legend()
    
    # Zoom to 0.94 to 1.0 range
    plt.ylim(0.94, 1.0)
    
    plt.tight_layout()
    plt.show()

## Model test function

In [22]:
# Test the model
def test_model(model):
    model.eval()
    test_loss = 0
    correct = 0
    print(model)
    
    # Download and load the test data
    testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
    test_loader = DataLoader(testset, batch_size=64, shuffle=True)
    
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
    
            output = model(x_batch)
            test_loss += nn.CrossEntropyLoss()(output, y_batch).item()
    
            _, argmax = torch.max(output, 1)
            correct += (y_batch == argmax.squeeze()).float().sum().item()
    
    test_loss /= len(test_loader.dataset)
    test_accuracy = correct / len(test_loader.dataset)
    
    print(f'Test loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

# Actual training

In [23]:
models = [
    LeNetTR(), LeNetTT(), LeNet()
]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
for x in models:
    x.to(device)
print("Device:", device)

LeNet TR ---
Original params: 283880
TN params: 3618
Compression ratio: 78.46323935876174
compression_ration is:  276.2995594713656
compression_ration is:  55.172413793103445
compression_ration is:  4.716981132075472
LeNet TT ---
Original params: 283880
TN params: 1700
Compression ratio: 166.98823529411766
Device: cuda:0


In [24]:
train_data = get_train_data()
epochs = 20
training_infos = []
for model in models:
    training_infos.append(train_model(model, device, train_data, epochs=epochs))
show_training_infos(training_infos, epochs)

6000
Training model
LeNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (classifier): Sequential(
    (0): Linear(in_features=784, out_features=320, bias=True)
    (1): ReLU()
    (2): Linear(in_features=320, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=10, bias=True)
  )
)
Parameter containing:
tensor([[-0.0328, -0.0008,  0.0249,  ..., -0.0200,  0.0080, -0.0281],
        [ 0.0185,  0.0167,  0.0329,  ...,  0.0293, -0.0023, -0.0325],
        [ 0.0039, -0.0181, -0.0266,  ..., -0.0067, -0.0284,  0.0294],
        ...,
        [-0.0356,  0.0238,  0.0038,  ...,  0.0084,  0.0087, -0.0033],
        [-0.0181,  0.0310,  0.0168,  ..., -0.0184, -0.0035,  0.0098],
        [-0.0248, -0.0201,  0.0252,  ...,  0.0065, -0.0147,  0.0234]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([-1.4427e-02,  3.3118e-02,  5.9915e-04, -2.3209e-02,  3.2347e-02,
        -3.4246e-02,  2.8080e-02, -1.9139e-02,  2.8579e-02,  1.2097e-03,
        

KeyboardInterrupt: 