## Подготовка данных

In [2]:
%load_ext autoreload
%autoreload 2

import torch

import jax.numpy as jnp
import scipy
import copy
import sys

import torch
from torch import nn
from torchvision.datasets import CIFAR10
from sklearn.preprocessing import StandardScaler
import numpy as np

from torchvision.models import resnet18
from torchvision.models import resnet50

from tucker_riemopt import Tucker



In [3]:
x = np.random.randn(4096, 10) / 10000000
X = Tucker.full2tuck(x)

In [4]:
import torchvision.transforms as T

# train_set = CIFAR10('CIFAR10', train=True, download=True)
# test_set = CIFAR10('CIFAR10', train=False, download=True)

normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

train_set = CIFAR10('CIFAR10', train=True, transform=T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomCrop(32, 4),
            T.ToTensor(),
            normalize,
        ]), download=True)
test_set = CIFAR10('CIFAR10', train=False, transform=T.Compose([
            T.ToTensor(),
            normalize,
        ]), download=True)

X_train = torch.Tensor(train_set.data)
X_test = torch.Tensor(test_set.data)

y_train = np.array(train_set.targets)
y_test = np.array(test_set.targets)

transforms = torch.nn.Sequential(
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
)

device = 'cuda:0'

# X_train = transforms(X_train.permute([0, 3, 1, 2])).to(device)
# X_test = transforms(X_test.permute([0, 3, 1, 2])).to(device)

X_train = transforms(X_train.permute([0, 3, 1, 2]))
X_test = transforms(X_test.permute([0, 3, 1, 2]))

Files already downloaded and verified
Files already downloaded and verified


In [5]:
from torch.utils.data import TensorDataset, Dataset, DataLoader

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

train_dataloader = DataLoader(train_dataset, batch_size=32)
val_dataloader = DataLoader(val_dataset, batch_size=32)

  train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
  val_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))


In [6]:
train_dataloader = torch.utils.data.DataLoader(
        CIFAR10('CIFAR10', train=True, transform=T.Compose([
            T.RandomHorizontalFlip(),
            # T.Resize(224),
            T.RandomCrop(32, 4),
            T.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=128, shuffle=False,
        num_workers=2, pin_memory=True)

val_dataloader = torch.utils.data.DataLoader(
        CIFAR10('CIFAR10', train=False, transform=T.Compose([
            # T.Resize(224),
            T.ToTensor(),
            normalize,
        ])),
        batch_size=128, shuffle=False,
        num_workers=2, pin_memory=True)

Files already downloaded and verified


## Модели

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.classifier = nn.Sequential(
            nn.AvgPool2d(8),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

        self.flatten = nn.Sequential(
            # nn.AvgPool2d(8),
            nn.Flatten(),
        )

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.classifier(out)
        return out
    
    def get_last(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        return out
        
    def get_last_flattened(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.flatten(out)
        return out


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))


if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()

resnet20
Total number of params 269722
Total layers 20

resnet32
Total number of params 464154
Total layers 32

resnet44
Total number of params 658586
Total layers 44

resnet56
Total number of params 853018
Total layers 56

resnet110
Total number of params 1727962
Total layers 110

resnet1202
Total number of params 19421274
Total layers 1202



## Обычный Linear

In [8]:
from tqdm.notebook import tqdm

def train_one_epoch(model, train_dataloader, criterion, optimizer, device="cuda:0"):
    progress_bar = tqdm(train_dataloader)
    model = model.to(device).train()
    idx = 0
    for (images, labels) in progress_bar:
        images, labels = images.to(device), labels.to(device)
        preds = model(images)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            progress_bar.set_description("Loss = {:.4f}".format(loss.item()))
        idx += 1


def predict(model, val_dataloader, criterion, device="cuda:0"):
    cumulative_loss = 0
    top1_acc = 0
    top5_acc = 0
    model = model.to(device).eval()
    predicted_classes = []
    true_classes = []
    with torch.no_grad():
        for idx, (images, labels) in enumerate(val_dataloader): 
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = criterion(preds, labels)
            predicted_classes.append(preds.argmax(1).float())
            true_classes.append(labels)
            cumulative_loss += loss.item()
            top1_acc += (preds.argsort(axis=1)[:,-1:].T == labels).float().sum()
            top5_acc += (preds.argsort(axis=1)[:,-5:].T == labels).float().sum()
    print(top1_acc)
    print(top5_acc)
    print("Loss = {:.4f}".format(cumulative_loss / idx), "top1 accuracy = {:.4f}".format(top1_acc / len(val_dataloader.dataset)), "top5 accuracy = {:.4f}".format(top5_acc / len(val_dataloader.dataset)))
    return cumulative_loss, torch.cat(predicted_classes).cpu(), torch.cat(true_classes).cpu()


def train(model, train_dataloader, val_dataloader, criterion, optimizer, device="cuda:0", n_epochs=10, scheduler=None):
    model = model.to(device)
    for epoch in range(n_epochs):
        train_one_epoch(model, train_dataloader, criterion, optimizer, device)
        loss, _, _ = predict(model, val_dataloader, criterion, device)
        if scheduler is not None:
            scheduler.step()

In [15]:
# model = resnet32()

# model = model.to(device)

criterion = nn.CrossEntropyLoss().cuda()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[90,135])

In [16]:
train(model, train_dataloader, val_dataloader, criterion, optimizer, device, scheduler=scheduler, n_epochs=1)

  0%|          | 0/391 [00:00<?, ?it/s]

tensor(9201., device='cuda:0')
tensor(9973., device='cuda:0')
Loss = 0.4176 top1 accuracy = 0.9201 top5 accuracy = 0.9973


In [17]:
predict(model, val_dataloader, criterion, device)
None

tensor(9201., device='cuda:0')
tensor(9973., device='cuda:0')
Loss = 0.4176 top1 accuracy = 0.9201 top5 accuracy = 0.9973


In [14]:
model.load_state_dict(torch.load('resnet32_adam_9203_copy2'))

<All keys matched successfully>

In [184]:
torch.save(model.state_dict(), 'resnet32_adam')

## Получаем результат после последнего сверточного слоя

In [18]:
def get_last(model, val_dataloader, criterion, device="cuda:0"):
    model = model.to(device).eval()
    result = []
    with torch.no_grad():
        for (images, labels) in tqdm(val_dataloader): 
            images, labels = images.to(device), labels.to(device)
            transformed = model.get_last(images)
            result += [transformed.cpu()]
    return result
  
def get_last_flattened(model, val_dataloader, criterion, device="cuda:0"):
    model = model.to(device).eval()
    result = []
    with torch.no_grad():
        for (images, labels) in tqdm(val_dataloader): 
            images, labels = images.to(device), labels.to(device)
            transformed = model.get_last_flattened(images)
            result += [transformed.cpu()]
    return result

In [19]:
X_train = np.vstack(get_last_flattened(model, train_dataloader, criterion, device))
X_test = np.vstack(get_last_flattened(model, val_dataloader, criterion, device))
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

(50000, 4096)
(10000, 4096)
(50000,)
(10000,)


## Получаем выход до flatten

In [None]:
X_train = np.vstack(get_last(model, train_dataloader, criterion, device))
X_test = np.vstack(get_last(model, val_dataloader, criterion, device))
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

In [None]:
X_train = torch.tensor(X_train).to(device)
y_train = torch.tensor(y_train).to(device)
X_test = torch.tensor(X_test).to(device)
y_test = torch.tensor(y_test).to(device)

## Flatten без римановой оптимизации

In [None]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.symmetric.optim import SGDmomentum as SGDmomentumSym
from tucker_riemopt.symmetric.tucker import Tucker as SymTucker

import wandb

c = 1

rank = (10, 4 * c, 2 * c, 2 * c)

d_in = 64 * 8 * 8
d_out = 10

# model = TuckerLinearSimpleSymmetric(d_in, d_out, nn.Linear(d_in, d_out), rank=rank, dims=[10, 64, 8, 8])

# optimizer = SGDmomentumSym(model.riemann_parameters(), rank=rank, max_lr=1e-3)
criterion = nn.CrossEntropyLoss()

def get_GVU(c):
    rank = (c, c)
    
    U1, _, _ = torch.linalg.svd(torch.randn(4096, rank[0]))
    V1= U1[:, :rank[0]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(10, rank[1]))
    V2 = U1[:, :rank[1]].to(device)

    G = torch.randn(rank).to(device) / 100000
    
    return G, V1, V2

# x_k = SymTucker(G, [V1, V2], num_symmetric_modes=2, symmetric_factor=U)

# optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, U]), rank=rank, max_lr=1e-3)

def forward(X, input):
    output = torch.einsum('ab,bi,cj,ij->ac', input, *X.factors, X.core)
    return output

In [None]:
def get_top1_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-1:] == labels[:, None]).sum() / preds.shape[0]

def get_top5_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-5:] == labels[:, None]).sum() / preds.shape[0]

def train_one_epoch(G, V1, V2, train_loader, optimizer, criterion, wandb_log):
    train_grad_norm = 0
    train_loss = 0
    
    log_every = 100
    
    for i, data in enumerate(tqdm(train_loader)):
        inputs, labels = data
        
        optimizer.zero_grad()
        preds = torch.einsum('ab,bi,cj,ij->ac', inputs, V1, V2, G)
        loss = criterion(preds, labels)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        
        if (i + 1) % log_every == 0:
            if wandb_log:
                wandb.log({
                    'train_loss': train_loss / log_every
                })
            else:
                # print('Train loss:', train_loss / log_every)
                # print()
                pass
            train_loss = 0
            pass

def eval(G, V1, V2, test_loader, criterion):
    loss = 0
    acc1 = 0
    acc5 = 0
    sz = 0
    
    x = Tucker(G, [V1 ,V2])
    
    for i, data in enumerate(tqdm(test_loader)):
        inputs, labels = data
        
        sz += inputs.shape[0]
        
        preds = torch.einsum('ab,bi,cj,ij->ac', inputs, V1, V2, G)
        acc1 += get_top1_accuracy(preds, labels) * inputs.shape[0]
        acc5 += get_top5_accuracy(preds, labels) * inputs.shape[0]
        
        loss += criterion(preds, labels)
    
    return loss / sz, acc1 / sz, acc5 / sz

def train(G, V1, V2, train_loader, test_loader, optimizer, criterion, n_epochs=10, wandb_log=False):
    if wandb_log:
        wandb.init('TRL symmetric')
    
    for i in range(n_epochs):
        loss, acc1, acc5 = eval(G, V1, V2, test_loader, criterion)
        if wandb_log:
            wandb.log({
                'test_loss': loss,
                'test_acc1': acc1,
                'test_acc5': acc5
            })
        print(f'Epoch {i}. loss={loss}, acc1={acc1}, acc5={acc5}')
        train_one_epoch(G, V1, V2, train_loader, optimizer, criterion, wandb_log=wandb_log)


In [None]:
device = 'cuda:0'
# X_train = torch.random(1024, 64, 8, 8).to(device)
# y_train = ((torch.randn(1024) > 100) * 1).to(device)

train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

test_set = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()

In [1]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.optim import SGDmomentum as SGDmomentum
from tucker_riemopt.tucker import Tucker as Tucker
from torch.optim import SGD
import wandb

c = 1

for c in [2, 3, 5, 10]:
    print(f'C = {c}')
    rank = (c, c)

    G, V1, V2 = get_GVU(c)
    for p in [G, V1, V2]:
        p.requires_grad = True
    print(f'N elems is {(G.numel() + V1.numel() + V2.numel())}')
    print(f'Compression is {(10 * 4096) / (G.numel() + V1.numel() + V2.numel())}')
    
    # optimizer = SGD(nn.ParameterList([G, V1, V2]), lr=1e-3, momentum=0.9)
    optimizer = SGD([G, V1, V2], lr=1e-3, momentum=0.9)

    train(G, V1, V2, train_loader, test_loader, optimizer, criterion, n_epochs=20, wandb_log=False)

## Flatten с римановой оптимизацией

In [None]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.symmetric.optim import SGDmomentum as SGDmomentumSym
from tucker_riemopt.symmetric.tucker import Tucker as SymTucker

import wandb

c = 1

rank = (10, 4 * c, 2 * c, 2 * c)

d_in = 64 * 8 * 8
d_out = 10

# model = TuckerLinearSimpleSymmetric(d_in, d_out, nn.Linear(d_in, d_out), rank=rank, dims=[10, 64, 8, 8])

# optimizer = SGDmomentumSym(model.riemann_parameters(), rank=rank, max_lr=1e-3)
criterion = nn.CrossEntropyLoss()

def get_GVU(c):
    rank = (c, c)
    
    U1, _, _ = torch.linalg.svd(torch.randn(4096, rank[0]))
    V1= U1[:, :rank[0]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(10, rank[1]))
    V2 = U1[:, :rank[1]].to(device)

    G = torch.randn(rank).to(device) / 100000
    
    return G, V1, V2

# x_k = SymTucker(G, [V1, V2], num_symmetric_modes=2, symmetric_factor=U)

# optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, U]), rank=rank, max_lr=1e-3)

def forward(X, input):
    output = torch.einsum('ab,bi,cj,ij->ac', input, *X.factors, X.core)
    return output

In [None]:
def get_top1_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-1:] == labels[:, None]).sum() / preds.shape[0]

def get_top5_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-5:] == labels[:, None]).sum() / preds.shape[0]

def train_one_epoch(G, V1, V2, train_loader, optimizer, criterion, wandb_log):
    train_grad_norm = 0
    train_loss = 0
    
    log_every = 100
    
    for i, data in enumerate(tqdm(train_loader)):
        inputs, labels = data
        
        def loss_fn(x):
            return criterion(forward(x, inputs), labels)
        
        # x_k = SymTucker(model.core.data, [model.factors[0], model.factors[1]], num_symmetric_modes=2, symmetric_factor=model.factors[-1])
        x_k = Tucker(G, [V1 ,V2])
        # x_k = Tucker(model.core.data, [model.R.weight, model.S.weight, model.O.weight])

        grad_norm = optimizer.fit(loss_fn, x_k)
        optimizer.step()
        train_grad_norm += grad_norm.detach()
        train_loss += optimizer.loss.detach()

        optimizer.zero_grad(set_to_none=True)
        
        if (i + 1) % log_every == 0:
            if wandb_log:
                wandb.log({
                    'train_grad_norm': train_grad_norm / log_every,
                    'train_loss': train_loss / log_every
                })
            else:
                print('Grad norm:', train_grad_norm / log_every)
                print('Train loss:', train_loss / log_every)
                print()
            train_grad_norm = 0
            train_loss = 0
            pass

def eval(G, V1, V2, test_loader, criterion):
    loss = 0
    acc1 = 0
    acc5 = 0
    sz = 0
    
    x = Tucker(G, [V1 ,V2])
    
    for i, data in enumerate(tqdm(test_loader)):
        inputs, labels = data
        
        sz += inputs.shape[0]
        
        def loss_fn(x):
            return criterion(forward(x, inputs), labels)
        
        preds = forward(x, inputs)
        acc1 += get_top1_accuracy(preds, labels) * inputs.shape[0]
        acc5 += get_top5_accuracy(preds, labels) * inputs.shape[0]
        
        loss += loss_fn(x)
    
    return loss / sz, acc1 / sz, acc5 / sz

def train(G, V1, V2, train_loader, test_loader, optimizer, criterion, n_epochs=10, wandb_log=False):
    if wandb_log:
        wandb.init('TRL symmetric')
    
    for i in range(n_epochs):
        loss, acc1, acc5 = eval(G, V1, V2, test_loader, criterion)
        if wandb_log:
            wandb.log({
                'test_loss': loss,
                'test_acc1': acc1,
                'test_acc5': acc5
            })
        print(f'Epoch {i}. loss={loss}, acc1={acc1}, acc5={acc5}')
        train_one_epoch(G, V1, V2, train_loader, optimizer, criterion, wandb_log=wandb_log)


In [None]:
device = 'cuda:0'
# X_train = torch.random(1024, 64, 8, 8).to(device)
# y_train = ((torch.randn(1024) > 100) * 1).to(device)

train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

test_set = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()

In [None]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.optim import SGDmomentum as SGDmomentum
from tucker_riemopt.tucker import Tucker as Tucker
from torch.optim import SGD
import wandb

c = 1

for c in [2, 3, 5, 10]:
    print(f'C = {c}')
    rank = (c, c)

    G, V1, V2 = get_GVU(c)
    print(f'N elems is {(G.numel() + V1.numel() + V2.numel())}')
    print(f'Compression is {(10 * 4096) / (G.numel() + V1.numel() + V2.numel())}')
    
    for p in [G, V1, V2]:
        p.requires_grad = True
    optimizer = SGDmomentum(nn.ParameterList([G, V1, V2]), rank=rank, max_lr=1e-4)

    train(G, V1, V2, train_loader, test_loader, optimizer, criterion, n_epochs=20, wandb_log=True)

## Tucker без римановой оптимизации

In [None]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.optim import SGDmomentum as SGDmomentum
from tucker_riemopt.tucker import Tucker as Tucker

def get_GV(c):
    rank = (10, 4 * c, 2 * c, 2 * c)
    
    U1, _, _ = torch.linalg.svd(torch.randn(10, rank[0]))
    V1= U1[:, :rank[0]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(64, rank[1]))
    V2 = U1[:, :rank[1]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(8, rank[2]))
    V3 = U1[:, :rank[2]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(8, rank[3]))
    V4 = U1[:, :rank[3]].to(device)

    G = torch.randn(rank).to(device) / 100000
    
    G = nn.Parameter(G)
    V1 = nn.Parameter(V1)
    V2 = nn.Parameter(V2)
    V3 = nn.Parameter(V3)
    V4 = nn.Parameter(V4)
    
    return G, V1, V2, V3, V4

# x_k = SymTucker(G, [V1, V2], num_symmetric_modes=2, symmetric_factor=U)

# optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, V3, V4]), rank=rank, max_lr=1e-3)

In [None]:
def get_top1_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-1:] == labels[:, None]).sum() / preds.shape[0]

def get_top5_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-5:] == labels[:, None]).sum() / preds.shape[0]

def train_one_epoch(G, V1, V2, V3, V4, train_loader, optimizer, criterion, wandb_log):
    train_grad_norm = 0
    train_loss = 0
    
    log_every = 100
    
    for i, data in enumerate(tqdm(train_loader)):
        inputs, labels = data
        
        optimizer.zero_grad()
        preds = torch.einsum('abcd,ih,be,cf,dg,hefg->ai', inputs, V1, V2, V3, V4, G)
        loss = criterion(preds, labels)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        
        if (i + 1) % log_every == 0:
            if wandb_log:
                wandb.log({
                    'train_loss': train_loss / log_every
                })
            else:
                print('Train loss:', train_loss / log_every)
                print()
            train_loss = 0
            pass

def eval(G, V1, V2, V3, V4, test_loader, criterion):
    loss = 0
    acc1 = 0
    acc5 = 0
    sz = 0
    
    x = Tucker(G, [V1 ,V2, V3, V4])
    
    for i, data in enumerate(tqdm(test_loader)):
        inputs, labels = data
        
        sz += inputs.shape[0]
        
        preds = torch.einsum('abcd,ih,be,cf,dg,hefg->ai', inputs, V1, V2, V3, V4, G)
        acc1 += get_top1_accuracy(preds, labels) * inputs.shape[0]
        acc5 += get_top5_accuracy(preds, labels) * inputs.shape[0]
        
        loss += criterion(preds, labels)
    
    return loss / sz, acc1 / sz, acc5 / sz

def train(G, V1, V2, V3, V4, train_loader, test_loader, optimizer, criterion, n_epochs=10, wandb_log=False):
    if wandb_log:
        wandb.init('TRL symmetric')
    
    for i in range(n_epochs):
        loss, acc1, acc5 = eval(G, V1, V2, V3, V4, test_loader, criterion)
        if wandb_log:
            wandb.log({
                'test_loss': loss,
                'test_acc1': acc1,
                'test_acc5': acc5
            })
        print(f'Epoch {i}. loss={loss}, acc1={acc1}, acc5={acc5}')
        train_one_epoch(G, V1, V2, V3, V4, train_loader, optimizer, criterion, wandb_log=wandb_log)


In [None]:
device = 'cuda:0'
# X_train = torch.random(1024, 64, 8, 8).to(device)
# y_train = ((torch.randn(1024) > 100) * 1).to(device)

train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

test_set = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()

In [None]:
from torch.optim import SGD
import wandb

c = 1

for c in [1, 2, 3, 4, 5]:
    print(f'C = {c}')
    rank = (10, 4 * c, 2 * c, 2 * c)

    G, V1, V2, V3, V4 = get_GV(c)
    optimizer = SGD(nn.ParameterList([G, V1, V2, V3, V4]), lr=1e-3, momentum=0.9)

    train(G, V1, V2, V3, V4, train_loader, test_loader, optimizer, criterion, n_epochs=20, wandb_log=True)

## Обычный Таккер

In [None]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.optim import SGDmomentum as SGDmomentum
from tucker_riemopt.tucker import Tucker as Tucker

def get_GV(c):
    rank = (10, 4 * c, 2 * c, 2 * c)
    
    U1, _, _ = torch.linalg.svd(torch.randn(10, rank[0]))
    V1= U1[:, :rank[0]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(64, rank[1]))
    V2 = U1[:, :rank[1]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(8, rank[2]))
    V3 = U1[:, :rank[2]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(8, rank[3]))
    V4 = U1[:, :rank[3]].to(device)

    G = torch.randn(rank).to(device) / 100000
    
    return G, V1, V2, V3, V4

# x_k = SymTucker(G, [V1, V2], num_symmetric_modes=2, symmetric_factor=U)

# optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, V3, V4]), rank=rank, max_lr=1e-3)

def forward(X, input):
    output = torch.einsum('abcd,ih,be,cf,dg,hefg->ai', input, *X.factors, X.core)
    return output

In [None]:
def get_top1_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-1:] == labels[:, None]).sum() / preds.shape[0]

def get_top5_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-5:] == labels[:, None]).sum() / preds.shape[0]

def train_one_epoch(G, V1, V2, V3, V4, train_loader, optimizer, criterion, wandb_log):
    train_grad_norm = 0
    train_loss = 0
    
    log_every = 100
    
    for i, data in enumerate(tqdm(train_loader)):
        inputs, labels = data
        
        def loss_fn(x):
            return criterion(forward(x, inputs), labels)
        
        # x_k = SymTucker(model.core.data, [model.factors[0], model.factors[1]], num_symmetric_modes=2, symmetric_factor=model.factors[-1])
        x_k = Tucker(G, [V1 ,V2, V3, V4])
        # x_k = Tucker(model.core.data, [model.R.weight, model.S.weight, model.O.weight])

        grad_norm = optimizer.fit(loss_fn, x_k)
        optimizer.step()
        train_grad_norm += grad_norm.detach()
        train_loss += optimizer.loss.detach()

        optimizer.zero_grad(set_to_none=True)
        
        if (i + 1) % log_every == 0:
            if wandb_log:
                wandb.log({
                    'train_grad_norm': train_grad_norm / log_every,
                    'train_loss': train_loss / log_every
                })
            else:
                print('Grad norm:', train_grad_norm / log_every)
                print('Train loss:', train_loss / log_every)
                print()
            train_grad_norm = 0
            train_loss = 0
            pass

def eval(G, V1, V2, V3, V4, test_loader, criterion):
    loss = 0
    acc1 = 0
    acc5 = 0
    sz = 0
    
    x = Tucker(G, [V1 ,V2, V3, V4])
    
    for i, data in enumerate(tqdm(test_loader)):
        inputs, labels = data
        
        sz += inputs.shape[0]
        
        def loss_fn(x):
            return criterion(forward(x, inputs), labels)
        
        preds = forward(x, inputs)
        acc1 += get_top1_accuracy(preds, labels) * inputs.shape[0]
        acc5 += get_top5_accuracy(preds, labels) * inputs.shape[0]
        
        loss += loss_fn(x)
    
    return loss / sz, acc1 / sz, acc5 / sz

def train(G, V1, V2, V3, V4, train_loader, test_loader, optimizer, criterion, n_epochs=10, wandb_log=False):
    if wandb_log:
        wandb.init('TRL symmetric')
    
    for i in range(n_epochs):
        loss, acc1, acc5 = eval(G, V1, V2, V3, V4, test_loader, criterion)
        if wandb_log:
            wandb.log({
                'test_loss': loss,
                'test_acc1': acc1,
                'test_acc5': acc5
            })
        print(f'Epoch {i}. loss={loss}, acc1={acc1}, acc5={acc5}')
        train_one_epoch(G, V1, V2, V3, V4, train_loader, optimizer, criterion, wandb_log=wandb_log)


In [None]:
device = 'cuda:0'
# X_train = torch.random(1024, 64, 8, 8).to(device)
# y_train = ((torch.randn(1024) > 100) * 1).to(device)

train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

test_set = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()

In [None]:
from torch.optim import SGD
import wandb

c = 1

for c in [1, 2, 3, 4, 5]:
    print(f'C = {c}')
    rank = (10, 4 * c, 2 * c, 2 * c)

    G, V1, V2, V3, V4 = get_GV(c)
    for p in [G, V1, V2, V3, V4]:
        p.requires_grad = True
    optimizer = SGDmomentum(nn.ParameterList([G, V1, V2, V3, V4]), rank=rank, max_lr=1e-4)

    train(G, V1, V2, V3, V4, train_loader, test_loader, optimizer, criterion, n_epochs=20, wandb_log=True)

## SF-Tucker

In [20]:
from copy import deepcopy
from tucker_riemopt.tucker import Tucker
from tucker_riemopt import backend as back

back.set_backend('pytorch')

model_orig = deepcopy(model)

In [21]:
from torch.utils.data import Dataset, TensorDataset, DataLoader
from src.TuckerLinear import TuckerLinearSimpleSymmetric
from tucker_riemopt.symmetric.optim import SGDmomentum as SGDmomentumSym
from tucker_riemopt.symmetric.tucker import Tucker as SymTucker

import wandb

c = 1

rank = (10, 4 * c, 2 * c, 2 * c)

d_in = 64 * 8 * 8
d_out = 10

# model = TuckerLinearSimpleSymmetric(d_in, d_out, nn.Linear(d_in, d_out), rank=rank, dims=[10, 64, 8, 8])

# optimizer = SGDmomentumSym(model.riemann_parameters(), rank=rank, max_lr=1e-3)
criterion = nn.CrossEntropyLoss()

def get_GVU(c):
    rank = (10, 4 * c, 2 * c, 2 * c)
    
    U1, _, _ = torch.linalg.svd(torch.randn(10, rank[0]))
    V1= U1[:, :rank[0]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(64, rank[1]))
    V2 = U1[:, :rank[1]].to(device)

    U1, _, _ = torch.linalg.svd(torch.randn(8, rank[2]))
    U = U1[:, :rank[2]].to(device)

    G = torch.randn(rank).to(device) / 100000
    
    return G, V1, V2, U

# x_k = SymTucker(G, [V1, V2], num_symmetric_modes=2, symmetric_factor=U)

# optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, U]), rank=rank, max_lr=1e-3)

def forward(X, input):
    output = torch.einsum('abcd,ih,be,cf,dg,hefg->ai', input, X.common_factors[0], X.common_factors[1], X.symmetric_factor, X.symmetric_factor, X.core)
    return output

In [22]:
def get_top1_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-1:] == labels[:, None]).sum() / preds.shape[0]

def get_top5_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-5:] == labels[:, None]).sum() / preds.shape[0]

def train_one_epoch(G, V1, V2, U, train_loader, optimizer, criterion, wandb_log):
    train_grad_norm = 0
    train_loss = 0
    
    log_every = 100
    
    for i, data in enumerate(tqdm(train_loader)):
        inputs, labels = data
        
        def loss_fn(x):
            return criterion(forward(x, inputs), labels)
        
        # x_k = SymTucker(model.core.data, [model.factors[0], model.factors[1]], num_symmetric_modes=2, symmetric_factor=model.factors[-1])
        x_k = SymTucker(G, [V1 ,V2], num_symmetric_modes=2, symmetric_factor=U)
        # x_k = Tucker(model.core.data, [model.R.weight, model.S.weight, model.O.weight])

        grad_norm = optimizer.fit(loss_fn, x_k)
        optimizer.step()
        train_grad_norm += grad_norm.detach()
        train_loss += optimizer.loss.detach()

        optimizer.zero_grad(set_to_none=True)
        
        if (i + 1) % log_every == 0:
            if wandb_log:
                wandb.log({
                    'train_grad_norm': train_grad_norm / log_every,
                    'train_loss': train_loss / log_every
                })
            else:
                print('Grad norm:', train_grad_norm / log_every)
                print('Train loss:', train_loss / log_every)
                print()
            train_grad_norm = 0
            train_loss = 0
            pass

def eval(G, V1, V2, U, test_loader, criterion):
    loss = 0
    acc1 = 0
    acc5 = 0
    sz = 0
    
    x = SymTucker(G, [V1 ,V2], num_symmetric_modes=2, symmetric_factor=U)
    
    for i, data in enumerate(tqdm(test_loader)):
        inputs, labels = data
        
        sz += inputs.shape[0]
        
        def loss_fn(x):
            return criterion(forward(x, inputs), labels)
        
        preds = forward(x, inputs)
        acc1 += get_top1_accuracy(preds, labels) * inputs.shape[0]
        acc5 += get_top5_accuracy(preds, labels) * inputs.shape[0]
        
        loss += loss_fn(x)
    
    return loss / sz, acc1 / sz, acc5 / sz

def train(G, V1, V2, U, train_loader, test_loader, optimizer, criterion, n_epochs=10, wandb_log=False):
    if wandb_log:
        wandb.init('TRL symmetric')
    
    for i in range(n_epochs):
        loss, acc1, acc5 = eval(G, V1, V2, U, test_loader, criterion)
        if wandb_log:
            wandb.log({
                'test_loss': loss,
                'test_acc1': acc1,
                'test_acc5': acc5
            })
        print(f'Epoch {i}. loss={loss}, acc1={acc1}, acc5={acc5}')
        train_one_epoch(G, V1, V2, U, train_loader, optimizer, criterion, wandb_log=wandb_log)


In [24]:
device = 'cuda:0'
# X_train = torch.random(1024, 64, 8, 8).to(device)
# y_train = ((torch.randn(1024) > 100) * 1).to(device)

train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

test_set = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()

In [24]:
print(len(train_set))
print(len(test_set))
print(X_train.shape)
print(X_test.shape)

50000
10000
torch.Size([50000, 64, 8, 8])
torch.Size([10000, 64, 8, 8])


In [25]:
c = 1

rank = (10, 4 * c, 2 * c, 2 * c)

G, V1, V2, U = get_GVU(c)
optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, U]), rank=rank, max_lr=1e-3)

In [1]:
c = 1

for c in [1, 2, 3, 4]:
    print(f'C = {c}')
    rank = (10, 4 * c, 2 * c, 2 * c)

    G, V1, V2, U = get_GVU(c)
    optimizer = SGDmomentumSym(nn.ParameterList([G, V1, V2, U]), rank=rank, max_lr=1e-3)

    train(G, V1, V2, U, train_loader, test_loader, optimizer, criterion, n_epochs=20, wandb_log=True)