In [5]:
import numpy as np

import torch 
from torch import nn, optim 
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from IPython.display import clear_output

import time 

from lib import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
class my_TT_linear(nn.Module):
    def __init__(self, inp_modes, out_modes, ranks):
        super().__init__()
        
        self.d = len(inp_modes)
        self.inp_modes = inp_modes 
        self.out_modes = out_modes 
        self.inp_size = np.prod(inp_modes)
        self.out_size = np.prod(out_modes)
        self.ranks = ranks
        
        self.W_cores = self.init_W_cores(inp_modes, out_modes)
        self.b = torch.nn.Parameter(torch.ones(self.out_size))
        
    def init_W_cores(self, inp_modes, out_modes):
        cores = torch.nn.ParameterList()
        for k in range(self.d):
            core = torch.randn(self.ranks[k], inp_modes[k], out_modes[k], self.ranks[k+1])
            core *= 2 / (inp_modes[k] + out_modes[k])
            cores.append(torch.nn.Parameter(core))
        return cores
            
    def forward(self, inp):
        W = TensorTrain(self.W_cores, self.inp_modes, self.out_modes, self.ranks)
        out = inp * W + self.b
        return out
    
    def backward(self, out):
        grad = self.init_W_cores(inp_modes, out_modes)
        return grad
    
    
class My_Optimizer(torch.optim.Optimizer):
    def __init__(self, optimizer):
        self.optimizer = optimizer
    
#     def step(self):
#         pass
    
    def __getattr__(self, attrname):
        print(attrname)
        return getattr(self.optimizer, attrname) 

In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        
        self.net= nn.Sequential(
            my_TT_linear([8, 8, 8], [4, 4, 4], [1, 2, 2, 1]), nn.ReLU(),
            nn.Linear(64, 10), nn.Softmax(dim=1),
        )

    def forward(self, x):
        return self.net(x)

In [6]:
N = 1000
inp_size = 512
out_size = 10

X = torch.rand(N, inp_size)
W_true = torch.rand(inp_size, out_size) * 2 / (inp_size + out_size)
y = torch.argmax(X @ W_true, 1)

In [7]:
from torch.utils.data import TensorDataset, DataLoader


trainset = TensorDataset(X, y)
train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=1)

In [25]:
model = Net().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train_loss, train_accuracy = train(model, train_dataloader, 
                                   criterion, optimizer, 
                                   n_epochs=1)

tensor(2.3556, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3503, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3556, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3340, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3193, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3523, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3386, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3487, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3413, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3320, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3386, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3367, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3385, grad_fn=<NllLossBackward0>) torch.Size([16]) torch.Size([])
tensor(2.3290, grad_fn=<N

In [20]:
model = Net().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
optimizer = My_Optimizer(optimizer)

train_loss, train_accuracy = train(model, train_dataloader, 
                                   criterion, optimizer, 
                                   n_epochs=5, show=True)

_zero_grad_profile_name
_zero_grad_profile_name
param_groups


TypeError: step() missing 1 required positional argument: 'closure'

In [78]:
optimizer.load_state_dict

<bound method Optimizer.load_state_dict of My_Optimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)>