In [1]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

In [2]:
from torchvision import datasets, transforms

In [3]:
from senmodel.model.utils import convert_dense_to_sparse_network, get_model_last_layer
from senmodel.metrics.edge_finder import EdgeFinder


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [5]:
BATCH_SIZE = 64

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

train_dataset = datasets.CIFAR10(root='./data', train=True,
                                  download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False,
                                  download=True, transform=transform)

# train_size = int(0.8 * len(train_dataset))
# val_size = len(train_dataset) - train_size
# # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# train_dataset, val_dataset, test_dataset = random_split(train_dataset, [train_size // 2, val_size // 2, len(train_dataset) - (train_size // 2 + val_size // 2)])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(len(train_loader))

Files already downloaded and verified
Files already downloaded and verified
782


In [6]:
class ExpandingHead(nn.Module):
    def __init__(self, input_size: int = 64, hidden_size: int = 50, output_size: int = 10):
        super().__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

In [7]:
class ResnetExp(nn.Module):
    def __init__(self, freeze_base: bool = False):
        super().__init__()
        self.base_model = torch.hub.load("chenyaofo/pytorch-cifar-models",
                                         "cifar10_resnet20", pretrained=True)
        self.base_model = torch.nn.Sequential(
            *(list(self.base_model.children())[:-1])
        )
        self.expanding_head = convert_dense_to_sparse_network(
            ExpandingHead(input_size=64, hidden_size=50, output_size=10)
        ).to(device)
        # if freeze_base:
        #     self.freeze(self.base_model)

    def forward(self, x):
        x = self.base_model(x)
        x = x.view(x.size(0), -1)
        x = self.expanding_head(x)
        return x

In [8]:
rexp = ResnetExp()
rexp = rexp.to(device)
img = val_dataset[0][0].unsqueeze(0).to(device)
rexp(img)

Using cache found in C:\Users\fedor/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master


tensor([[-0.2109, -0.1162, -0.1401, -0.0713,  0.1549, -0.1437,  0.2410,  0.0094,
         -0.0911,  0.0005]], grad_fn=<AsStridedBackward0>)

In [9]:
def edge_replacement_func_new_layer(model, optim, val_loader, metric, choose_threshold, aggregation_mode='mean', len_choose=None):
    layer = get_model_last_layer(model)
    ef = EdgeFinder(metric, val_loader, device, aggregation_mode)
    vals = ef.calculate_edge_metric_for_dataloader(model, len_choose, False)
    print("Edge metrics:", vals, max(vals, default=0), sum(vals))
    chosen_edges = ef.choose_edges_threshold(model, choose_threshold, len_choose)
    print("Chosen edges:", chosen_edges, len(chosen_edges[0]))
    layer.replace_many(*chosen_edges)

    if len(chosen_edges[0]) > 0:
        optim.add_param_group({'params': layer.embed_linears[-1].weight_values})
    else:
        print("Empty metric")

    return {'max': max(vals, default=0), 'sum': sum(vals), 'len': len(vals), 'len_choose': layer.count_replaces[-1]}

In [10]:
from senmodel.model.utils import freeze_all_but_last, freeze_only_last
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import accuracy_score


def train_sparse_recursive(model, train_loader, val_loader, num_epochs, metric, edge_replacement_func=None,
                           window_size=3, threshold=0.1, lr=5e-4, choose_threshold=0.3, aggregation_mode='mean', replace_all_epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_losses = []

    len_choose = get_model_last_layer(model).count_replaces

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()


        for i, (inputs, targets) in enumerate(tqdm(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if len(len_choose) > replace_all_epochs and i > window_size:
                freeze_all_but_last(model)

            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_accuracy = accuracy_score(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}")


        new_l = dict()

        val_losses.append(val_loss)
        if edge_replacement_func and len(val_losses) > window_size:
            recent_changes = [abs(val_losses[i] - val_losses[i - 1]) for i in range(-window_size, 0)]
            avg_change = sum(recent_changes) / window_size
            if avg_change < threshold:
                print(f"{len_choose=}")
                len_ch = len_choose[-1] if len(len_choose) > replace_all_epochs else None
                new_l = edge_replacement_func(model, optimizer, val_loader, metric, choose_threshold, aggregation_mode, len_ch)
                # Замораживаем все слои кроме последнего
                val_losses = []
                len_choose = get_model_last_layer(model).count_replaces

        wandb.log({'val_loss': val_loss, 'val_accuracy': val_accuracy, 'train_loss': train_loss} | new_l)

In [11]:
from senmodel.metrics.nonlinearity_metrics import *

criterion = nn.CrossEntropyLoss()
metrics = [
    AbsGradientEdgeMetric(criterion),
    ReversedAbsGradientEdgeMetric(criterion),
    SNIPMetric(criterion),
    MagnitudeL2Metric(criterion),
]


In [12]:
import wandb

wandb.login()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: fedornigretuk. Use `wandb login --relogin` to force relogin


True

In [13]:
hyperparams = {"num_epochs": 50,
               "metric": metrics[0],
               "aggregation_mode": "mean",
               "choose_threshold": 0.1,
               "window_size": 3,
               "threshold": 0.05,
               "lr": 5e-4,
               "replace_all_epochs": 2
               }

name = ", ".join(
    f"{key}: {value.__class__.__name__ if key == 'metric' else value}"
    for key, value in hyperparams.items()
)
name

'num_epochs: 50, metric: AbsGradientEdgeMetric, aggregation_mode: mean, choose_threshold: 0.1, window_size: 3, threshold: 0.05, lr: 0.0005, replace_all_epochs: 2'

In [15]:

sparse_model = ResnetExp()
wandb.finish()
wandb.init(
    project="self-expanding-nets",
    name=f"cifar10, {name}",
    tags=["complex model", hyperparams["metric"].__class__.__name__],
    group="new freeze 2"
)

train_sparse_recursive(sparse_model, train_loader, val_loader,
                       edge_replacement_func=edge_replacement_func_new_layer, **hyperparams)
wandb.finish()

Using cache found in C:\Users\fedor/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master


100%|██████████| 782/782 [03:51<00:00,  3.38it/s]


Epoch 1/50 | Train Loss: 0.8213 | Val Loss: 66.4344 | Val Accuracy: 0.8744


100%|██████████| 782/782 [03:02<00:00,  4.29it/s]


Epoch 2/50 | Train Loss: 0.4196 | Val Loss: 64.7578 | Val Accuracy: 0.8800


100%|██████████| 782/782 [03:00<00:00,  4.34it/s]


Epoch 3/50 | Train Loss: 0.3539 | Val Loss: 62.6145 | Val Accuracy: 0.8845


100%|██████████| 782/782 [02:54<00:00,  4.49it/s]


Epoch 4/50 | Train Loss: 0.3072 | Val Loss: 62.2218 | Val Accuracy: 0.8909


100%|██████████| 782/782 [02:55<00:00,  4.44it/s]


Epoch 5/50 | Train Loss: 0.2884 | Val Loss: 77.0360 | Val Accuracy: 0.8692


100%|██████████| 782/782 [03:01<00:00,  4.31it/s]


Epoch 6/50 | Train Loss: 0.2631 | Val Loss: 61.0746 | Val Accuracy: 0.8902


100%|██████████| 782/782 [02:58<00:00,  4.37it/s]


Epoch 7/50 | Train Loss: 0.2422 | Val Loss: 75.6958 | Val Accuracy: 0.8787


100%|██████████| 782/782 [02:56<00:00,  4.44it/s]


Epoch 8/50 | Train Loss: 0.2240 | Val Loss: 68.7817 | Val Accuracy: 0.8926


100%|██████████| 782/782 [03:00<00:00,  4.34it/s]


Epoch 9/50 | Train Loss: 0.2188 | Val Loss: 66.4199 | Val Accuracy: 0.8940


  6%|▋         | 49/782 [00:11<02:57,  4.13it/s]


KeyboardInterrupt: 