## Подготовка данных

In [1]:
%load_ext autoreload
%autoreload 2

import torch

import jax.numpy as jnp
import scipy
import copy
import sys

In [2]:
import torch
from torch import nn
from sklearn.preprocessing import StandardScaler
import numpy as np

from torchvision.models import resnet18
from torchvision.models import resnet50
from torchvision.datasets import ImageNet

In [3]:
import torchvision.transforms as T

train_set = np.load('data/train_data_batch_1', allow_pickle=True)
test_set = np.load('data/val_data', allow_pickle=True)

X_train = torch.Tensor(train_set['data'].reshape(-1, 3, 32, 32))
y_train = np.array(train_set['labels']) - 1

X_test = torch.Tensor(test_set['data'].reshape(-1, 3, 32, 32))
y_test = np.array(test_set['labels']) - 1

transforms = torch.nn.Sequential(
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
)

device = 'cpu'

X_train = transforms(X_train).to(device)
X_test = transforms(X_test).to(device)

In [8]:
from torch.utils.data import TensorDataset, Dataset, DataLoader

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

train_dataloader = DataLoader(train_dataset, batch_size=128)
val_dataloader = DataLoader(val_dataset, batch_size=128)

  train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
  val_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))


In [9]:
resnet = resnet18(pretrained=True)
resnet.fc = nn.Identity()

for param in resnet.parameters():
    param.requires_grad = False

## Обычный Linear

In [10]:
from tqdm.notebook import tqdm

def train_one_epoch(model, train_dataloader, criterion, optimizer, device="cuda:0"):
    progress_bar = tqdm(train_dataloader)
    model = model.to(device).train()
    idx = 0
    for (images, labels) in progress_bar:
        images, labels = images.to(device), labels.to(device)
        preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if idx % 10 == 0:
            progress_bar.set_description("Loss = {:.4f}".format(loss.item()))
        idx += 1


def predict(model, val_dataloader, criterion, device="cuda:0"):
    cumulative_loss = 0
    top1_acc = 0
    top5_acc = 0
    model = model.to(device).eval()
    predicted_classes = []
    true_classes = []
    with torch.no_grad():
        for idx, (images, labels) in enumerate(val_dataloader): 
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = criterion(preds, labels)
            predicted_classes.append(preds.argmax(1).float())
            true_classes.append(labels)
            cumulative_loss += loss.item()
            top1_acc += (preds.argsort(axis=1)[:,-1:].T == labels).float().sum()
            top5_acc += (preds.argsort(axis=1)[:,-5:].T == labels).float().sum()
    print(top1_acc)
    print(top5_acc)
    print("Loss = {:.4f}".format(cumulative_loss / idx), "top1 accuracy = {:.4f}".format(top1_acc / len(val_dataloader.dataset)), "top5 accuracy = {:.4f}".format(top5_acc / len(val_dataloader.dataset)))
    return cumulative_loss, torch.cat(predicted_classes).cpu(), torch.cat(true_classes).cpu()


def train(model, train_dataloader, val_dataloader, criterion, optimizer, device="cuda:0", n_epochs=10, scheduler=None):
    model = model.to(device)
    for epoch in range(n_epochs):
        train_one_epoch(model, train_dataloader, criterion, optimizer, device)
        loss, _, _ = predict(model, val_dataloader, criterion, device)
        if scheduler is not None:
            scheduler.step(loss)

In [11]:
model = resnet18(pretrained=True)
model.fc = nn.Linear(512, 1000)

for param in model.parameters():
    param.requires_grad = False

for param in model.fc.parameters():
    param.requires_grad = True

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), 1e-4)
criterion = nn.CrossEntropyLoss()

In [12]:
train(model, train_dataloader, val_dataloader, criterion, optimizer, device, n_epochs=5)

  0%|          | 0/1001 [00:00<?, ?it/s]

tensor(583.)
tensor(1900.)
Loss = 6.7273 top1 accuracy = 0.0117 top5 accuracy = 0.0380


  0%|          | 0/1001 [00:00<?, ?it/s]

tensor(1130.)
tensor(3292.)
Loss = 6.5025 top1 accuracy = 0.0226 top5 accuracy = 0.0658


  0%|          | 0/1001 [00:00<?, ?it/s]

tensor(1396.)
tensor(4037.)
Loss = 6.3814 top1 accuracy = 0.0279 top5 accuracy = 0.0807


  0%|          | 0/1001 [00:00<?, ?it/s]

tensor(1543.)
tensor(4466.)
Loss = 6.3102 top1 accuracy = 0.0309 top5 accuracy = 0.0893


  0%|          | 0/1001 [00:00<?, ?it/s]

tensor(1665.)
tensor(4747.)
Loss = 6.2659 top1 accuracy = 0.0333 top5 accuracy = 0.0949


In [13]:
predict(model, val_dataloader, criterion, device)
None

tensor(1665.)
tensor(4747.)
Loss = 6.2659 top1 accuracy = 0.0333 top5 accuracy = 0.0949


## Таккер без римановых оптимизаций

In [14]:
def get_top1_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-1:] == labels.reshape(-1, 1)).sum() / preds.shape[0]

def get_top5_accuracy(preds, labels):
    return (preds.argsort(axis=1)[:,-5:] == labels.reshape(-1, 1) @ np.ones(5).reshape(1, -1)).sum() / preds.shape[0]

In [15]:
def optimize_GD(f, X0, maxiter=10, lr=1e-3, verbose=False):
    """
    Input
        f: function to maximize
        X0: first approximation
        maxiter: number of iterations to perform

    Output
        Xk: approximation after maxiter iterations
        errs: values of functional on each step
    """
    X = X0
    max_rank = np.max(X.rank)

    errs = []
    errs.append(f(X))

    @jax.jit
    def g(core, factors):
        T = Tucker(core, factors)
        return f(T)

    dg_dS = jax.grad(g, argnums=0)
    dg_dU = jax.grad(g, argnums=1)
    
    for i in range(maxiter):
        dS = dg_dS(X.core, X.factors)
        dU = dg_dU(X.core, X.factors)

        tau = lr

        new_core = X.core + tau * dS
        new_factors = []

        for i in range(X.ndim):
            new_factors.append(X.factors[i] + tau * dU[i])

        X = Tucker(new_core, new_factors)
        
        errs.append(f(X))

        if verbose:
            print(f'Done iteration {i+1}/{maxiter}!\t Error: {errs[-1]}' + ' ' * 50, end='\n')

    return X, errs

In [16]:
from tucker_riemopt.src.noriemopt import optimize_no_riemopt
from tucker_riemopt.src.riemopt import optimize
from tucker_riemopt.src.matrix import TuckerMatrix
from tucker_riemopt.src.tucker import Tucker

import jax
import jax.config
jax.config.update("jax_enable_x64", True)

from tqdm.notebook import tqdm

def step_GD(images, labels, X, lr=1e-3):
    embeds = jnp.array(resnet(images))
    
    @jax.jit
    def f(T):
        preds = T.k_mode_product(0, embeds).full()
        probs = jnp.exp(preds)
        probs = (probs.T / probs.sum(axis=1)).T
        return -jnp.sum(-jnp.log(probs[np.arange(labels.shape[0]), labels]))
    
    X, _ = optimize_GD(f, X, maxiter=100, lr=lr)
    return X

def predict(images, X):
    res = []
    for i in tqdm(range((images.shape[0] - 1) // 500 + 1)):
        embeds = jnp.array(resnet(images[i * 500 : (i + 1) * 500]))

        preds = X.k_mode_product(0, embeds).full()
        res.append(preds)
    return np.vstack(res)

In [17]:
x = np.random.randn(512, 1000) / 10000
X = Tucker.full2tuck(x)



In [22]:
for i in tqdm(range((X_train.shape[0] - 1) // 500 + 1)):
    X = step_GD(X_train[i * 500 : (i + 1) * 500], y_train[i * 500 : (i + 1) * 500], X, lr=1e-4)
    if i % 64 == 0:
        preds = predict(X_test, X)
        print(f'Top1 accuracy: {get_top1_accuracy(preds, y_test)}')
        print(f'Top5 accuracy: {get_top5_accuracy(preds, y_test)}')

  0%|          | 0/257 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00112
Top5 accuracy: 0.005


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.0045
Top5 accuracy: 0.01978


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00542
Top5 accuracy: 0.02148


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00616
Top5 accuracy: 0.0237


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.0064
Top5 accuracy: 0.02434


In [23]:
preds = predict(X_test, X)
print(f'Top1 accuracy: {get_top1_accuracy(preds, y_test)}')
print(f'Top5 accuracy: {get_top5_accuracy(preds, y_test)}')

  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.0064
Top5 accuracy: 0.02434


## Таккер с римановыми оптимизациями

In [28]:
from tucker_riemopt.src.noriemopt import optimize_no_riemopt
from tucker_riemopt.src.riemopt import optimize
from tucker_riemopt.src.matrix import TuckerMatrix
from tucker_riemopt.src.tucker import Tucker

import jax
import jax.config
jax.config.update("jax_enable_x64", True)

from tqdm.notebook import tqdm

def step_riemopt(images, labels, X):
    embeds = jnp.array(resnet(images))
    
    @jax.jit
    def f(T):
        preds = T.k_mode_product(0, embeds).full()
        probs = jnp.exp(preds)
        probs = (probs.T / probs.sum(axis=1)).T
        return -jnp.sum(-jnp.log(probs[np.arange(labels.shape[0]), labels]))
    
    X, _ = optimize(f, X, maxiter=100)
    return X

In [29]:
x = np.random.randn(512, 10) / 10000
X = Tucker.full2tuck(x)

In [30]:
for i in tqdm(range((X_train.shape[0] - 1) // 500 + 1)):
    X = step_riemopt(X_train[i * 500 : (i + 1) * 500], y_train[i * 500 : (i + 1) * 500], X)
    if i % 64 == 0:
        preds = predict(X_test, X)
        print(f'Top1 accuracy: {get_top1_accuracy(preds, y_test)}')
        print(f'Top5 accuracy: {get_top5_accuracy(preds, y_test)}')

  0%|          | 0/257 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00106
Top5 accuracy: 0.00476


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00186
Top5 accuracy: 0.0062


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00198
Top5 accuracy: 0.0069


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.0025
Top5 accuracy: 0.0064


  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00234
Top5 accuracy: 0.00682


In [31]:
preds = predict(X_test, X)
print(f'Top1 accuracy: {get_top1_accuracy(preds, y_test)}')
print(f'Top5 accuracy: {get_top5_accuracy(preds, y_test)}')

  0%|          | 0/100 [00:00<?, ?it/s]

Top1 accuracy: 0.00234
Top5 accuracy: 0.00682
