In [3]:
%load_ext autoreload

%autoreload 2

In [4]:
import numpy as np
import torch
import torchvision
import tqdm
from torch import nn
from torchvision import transforms as T
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import random

In [5]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [6]:
from torchvision.datasets import CIFAR10

In [7]:
train_transform = T.Compose([
    T.RandomCrop((32, 32), padding=4),
    T.RandomHorizontalFlip(0.5),
    T.ColorJitter(contrast=0.25),
    T.ToTensor(),
    T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])
val_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

train_dataset = CIFAR10("./data/", download=True, train=True, transform=train_transform)
val_dataset = CIFAR10("./data/", download=True, train=False, transform=val_transform)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=100)

In [9]:
from src.utils import set_random_seed

In [10]:
from src.utils import train, predict

In [11]:
from src.layers.trl_masked import TRLMasked

set_random_seed(12345)
model = nn.Sequential(
    nn.Sequential(
        nn.Conv2d(3, 64, (3, 3), padding="same"),
        nn.BatchNorm2d(64),
        nn.ReLU(),
    ),
    nn.Sequential(
        nn.Conv2d(64, 64, (3, 3), padding="same"),
        nn.BatchNorm2d(64),
        nn.ReLU(),
    ),
    nn.MaxPool2d((3, 3), stride=(2, 2), padding=(1, 1)),
    nn.Sequential(
        nn.Conv2d(64, 128, (3, 3), padding="same"),
        nn.BatchNorm2d(128),
        nn.ReLU(),
    ),
    nn.Sequential(
        nn.Conv2d(128, 128, (3, 3), padding="same"),
        nn.BatchNorm2d(128),
        nn.ReLU(),
    ),
    nn.MaxPool2d((3, 3), stride=(2, 2), padding=(1, 1)),
    nn.Sequential(
        nn.Conv2d(128, 128, (3, 3), padding="same"),
        nn.BatchNorm2d(128),
        nn.ReLU(),
    ),
    nn.Sequential(
        nn.Conv2d(128, 128, (3, 3), padding="same"),
        nn.BatchNorm2d(128),
        nn.ReLU(),
    ),
    nn.Sequential(
        TRLMasked((128, 8, 8), (1, 1, 1), 10),
    )
)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
criterion = nn.CrossEntropyLoss(reduction="mean")
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1, verbose=True)
# scheduler = None
n_epochs = 100

device = "cuda:0" if torch.cuda.is_available() else torch.device("cpu")
# device = "cpu"
model = model.to(device)

Adjusting learning rate of group 0 to 1.0000e-02.


In [17]:
# model = torch.load("./models/test_model.pt")
# model[-1][0].clear_masks()

In [18]:
import tensornetwork as tn
import numpy as np
from copy import deepcopy

def append_to_axis(data, axis, tensor=None, device="cuda:0"):
    with torch.no_grad():
        std = torch.sqrt(torch.norm(data)**2 / np.prod(data.shape))
        if tensor is not None:
            return torch.cat((data, tensor), dim=axis)
        zeros_shape = list(data.shape)
        zeros_shape[axis] = 1
        return torch.cat((data, torch.normal(0, std, zeros_shape, device=device)), dim=axis)

def decrease_node(tensor, axis):
    with torch.no_grad():
        indices = [slice(None) for _ in tensor.data.shape]
        indices[axis] = slice(-1, None)
        tensor_slice = tensor.data[indices].detach().clone()
        indices[axis] = slice(-1)
        tensor.data = tensor.data[indices].clone()
        return tensor_slice

def decrease_edge(edge: tn.Edge, device="cuda:0"):
    return (decrease_node(edge.node1.tensor, edge.axis1), decrease_node(edge.node2.tensor, edge.axis2))

def increase_edge(edge: tn.Edge, tensors=(None, None), device="cuda:0"):
    with torch.no_grad():
        edge.node1.tensor.data = append_to_axis(edge.node1.tensor.data, edge.axis1, tensors[0], device=device)
        mask1 = torch.zeros_like(edge.node1.tensor.data, device="cpu", dtype=torch.long)
        mask1.index_fill_(edge.axis1, torch.LongTensor([-1]), 1)
        edge.node2.tensor.data = append_to_axis(edge.node2.tensor.data, edge.axis2, tensors[1], device=device)
        mask2 = torch.zeros_like(edge.node2.tensor.data, device="cpu", dtype=torch.long)
        mask2.index_fill_(edge.axis2, torch.LongTensor([-1]), 1)
        return mask1.to(device), mask2.to(device)

def increase_edge_in_layer(layer, edge, device="cuda:0"):
    with torch.no_grad():
        cloned_layer = deepcopy(layer)
        cloned_edge = [e for e in tn.get_all_nondangling(cloned_layer.construct_network()) 
            if e.node1.name == edge.node1.name and e.node2.name == edge.node2.name][0]
        cloned_layer.masks = {}
        mask1, mask2 =increase_edge(cloned_edge)
        cloned_layer.masks[edge.node1.name] = mask1
        cloned_layer.masks[edge.node2.name] = mask2

        return cloned_layer, cloned_edge


In [19]:
def choose_and_increase_edge(model, tensor_network, train_edge):
    edges = list(tn.get_all_nondangling(tensor_network))
    edges_results = []
    old_layer = model[-1][0]
    for edge in edges:
        model[-1][0], cloned_edge = increase_edge_in_layer(old_layer, edge)

        for param in model.parameters():
            param.requires_grad = False
        cloned_edge.node1.tensor.requires_grad = True
        cloned_edge.node2.tensor.requires_grad = True

        print("Train edge {}".format(edge.name))
        final_loss = train_edge(model)
        model[-1][0].clear_masks()
        edges_results.append((model[-1][0], final_loss, edge.name))
    model[-1][0], _, edgename = min(edges_results, key=lambda x: x[1])
    print("Choosen edge {}".format(edgename))

In [20]:
def train_edge(model):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    train(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler=None, n_epochs=2, device=device)
    all_losses, _, _ = predict(model, train_dataloader, criterion, device)
    return np.sum(all_losses)

for i in range(30):
    for param in model.parameters():
        param.requires_grad = True
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    print("Train full model")
    train(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler=None, n_epochs=1, device=device)
    choose_and_increase_edge(model, model[-1][0].construct_network(), train_edge)

Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.9307632489204407, Validation loss: 1.789718461036682, Validation accuracy: 0.2599
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.78416850566864, Validation loss: 1.7244648015499116, Validation accuracy: 0.2853


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.754244993686676, Validation loss: 1.7130224931240081, Validation accuracy: 0.2953
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.736805818080902, Validation loss: 1.6757553791999817, Validation accuracy: 0.3368


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.6982928307056426, Validation loss: 1.6620989179611205, Validation accuracy: 0.3518
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.8165272789001465, Validation loss: 1.749435863494873, Validation accuracy: 0.2633


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.7772827990055085, Validation loss: 1.7397427022457124, Validation accuracy: 0.3034
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.4866114370822907, Validation loss: 1.4027899777889252, Validation accuracy: 0.4712
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.3430445632934571, Validation loss: 1.308541386127472, Validation accuracy: 0.5229


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.3139754197597504, Validation loss: 1.304700391292572, Validation accuracy: 0.5181
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.3344928710460662, Validation loss: 1.291877233982086, Validation accuracy: 0.5289


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.2921081638336183, Validation loss: 1.268626824617386, Validation accuracy: 0.5367
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.3547518610954286, Validation loss: 1.308802956342697, Validation accuracy: 0.5117


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 1.3263384499549866, Validation loss: 1.2995866191387178, Validation accuracy: 0.5148
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 1.092422968864441, Validation loss: 1.0138880425691605, Validation accuracy: 0.6362
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.9707968008518219, Validation loss: 0.9291025811433792, Validation accuracy: 0.6644


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.9451832293272019, Validation loss: 0.9191816705465317, Validation accuracy: 0.6701
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.9900138903856277, Validation loss: 0.9615352499485016, Validation accuracy: 0.6488


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.9627003179788589, Validation loss: 0.9381575804948806, Validation accuracy: 0.6589
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.9808022487163544, Validation loss: 0.9393899095058441, Validation accuracy: 0.6614


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.9548983644247055, Validation loss: 0.9344148832559586, Validation accuracy: 0.6674
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.8762726247310638, Validation loss: 0.8761548915505409, Validation accuracy: 0.6887
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.7948588652610779, Validation loss: 0.8084081149101258, Validation accuracy: 0.7222


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.7720650094151497, Validation loss: 0.7979872846603393, Validation accuracy: 0.7289
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.823607914686203, Validation loss: 0.8160923933982849, Validation accuracy: 0.7229


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.7723347393274307, Validation loss: 0.8178819665312766, Validation accuracy: 0.7233
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.8085252425074577, Validation loss: 0.830833925306797, Validation accuracy: 0.7156


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.7787712095975876, Validation loss: 0.809297656416893, Validation accuracy: 0.7257
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.7399880170822144, Validation loss: 0.8149229687452316, Validation accuracy: 0.7394
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.6617720746397973, Validation loss: 0.6928267621994019, Validation accuracy: 0.7708


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.6392206225991249, Validation loss: 0.6872291287779808, Validation accuracy: 0.7715
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.6602500823140144, Validation loss: 0.7043652325868607, Validation accuracy: 0.7652


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.6568389899134636, Validation loss: 0.7040628093481064, Validation accuracy: 0.7647
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.6546978765130043, Validation loss: 0.7002277860045433, Validation accuracy: 0.7702


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.6491063430309295, Validation loss: 0.6968226647377014, Validation accuracy: 0.769
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.6353593991398812, Validation loss: 0.7591306313872337, Validation accuracy: 0.7533
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.6027866270542145, Validation loss: 0.6221936973929405, Validation accuracy: 0.792


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.5850700612068176, Validation loss: 0.618477708697319, Validation accuracy: 0.7956
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.607025422334671, Validation loss: 0.6273889264464378, Validation accuracy: 0.7873


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.5951623556017875, Validation loss: 0.6252178165316582, Validation accuracy: 0.7872
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.5881635114550591, Validation loss: 0.6132561233639717, Validation accuracy: 0.7934


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.5721756559610367, Validation loss: 0.6098536571860314, Validation accuracy: 0.7957
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.5683269039392471, Validation loss: 0.8311353331804275, Validation accuracy: 0.7381
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.5275884244441986, Validation loss: 0.5702491843700409, Validation accuracy: 0.8108


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.5097877843976021, Validation loss: 0.5655735284090042, Validation accuracy: 0.8128
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.5156317266821862, Validation loss: 0.5707490158081054, Validation accuracy: 0.8084


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.5058294162750244, Validation loss: 0.560108018219471, Validation accuracy: 0.8114
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.528341213464737, Validation loss: 0.575655372440815, Validation accuracy: 0.809


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.5142800963521004, Validation loss: 0.5751574018597603, Validation accuracy: 0.8112
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.5109683022499084, Validation loss: 0.6162123185396194, Validation accuracy: 0.7929
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.5023467887043953, Validation loss: 0.5316135859489441, Validation accuracy: 0.822


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.47839241528511045, Validation loss: 0.5236048829555512, Validation accuracy: 0.8223
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.47997385007143023, Validation loss: 0.5179517820477486, Validation accuracy: 0.8264


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.4728778330385685, Validation loss: 0.5194916918873786, Validation accuracy: 0.8259
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.4922680247426033, Validation loss: 0.5278846594691277, Validation accuracy: 0.8223


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.4857745749950409, Validation loss: 0.5273168125748634, Validation accuracy: 0.8241
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.4791398582458496, Validation loss: 0.5754622590541839, Validation accuracy: 0.8207
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.41726872792840003, Validation loss: 0.5026111328601837, Validation accuracy: 0.8371


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.40933638653159143, Validation loss: 0.49965794160962107, Validation accuracy: 0.8367
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.4233989505469799, Validation loss: 0.5073805648088455, Validation accuracy: 0.8336


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.4125578469336033, Validation loss: 0.5039419792592525, Validation accuracy: 0.8368
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.42090491235256194, Validation loss: 0.5068077039718628, Validation accuracy: 0.8326


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.4196670994758606, Validation loss: 0.5089886230230332, Validation accuracy: 0.8316
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.4376602989435196, Validation loss: 0.588574442267418, Validation accuracy: 0.817
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.41306345582008364, Validation loss: 0.5303925812244416, Validation accuracy: 0.8309


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.4030250756144524, Validation loss: 0.521735480427742, Validation accuracy: 0.8379
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.40429456943273545, Validation loss: 0.5186765591800213, Validation accuracy: 0.838


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3978156173825264, Validation loss: 0.5129522998631001, Validation accuracy: 0.8395
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.4112763793170452, Validation loss: 0.5250549805164337, Validation accuracy: 0.8348


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3990514058470726, Validation loss: 0.5161637668311596, Validation accuracy: 0.8385
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.40583752912282944, Validation loss: 0.5031927613914013, Validation accuracy: 0.8406
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.37145467185974124, Validation loss: 0.46341968819499013, Validation accuracy: 0.8545


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.35646157470345496, Validation loss: 0.45916716307401656, Validation accuracy: 0.8515
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.36534228903055194, Validation loss: 0.45639585569500923, Validation accuracy: 0.8543


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3541098207831383, Validation loss: 0.4475783056020737, Validation accuracy: 0.857
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.36767544919252393, Validation loss: 0.46553144961595533, Validation accuracy: 0.8516


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3626493310928345, Validation loss: 0.4632744871079922, Validation accuracy: 0.8516
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.3724780757725239, Validation loss: 0.6814893960952759, Validation accuracy: 0.7989
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.35738255697488786, Validation loss: 0.46017508670687673, Validation accuracy: 0.8487


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3459577389061451, Validation loss: 0.45221948623657227, Validation accuracy: 0.8513
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.3534571953415871, Validation loss: 0.470098665356636, Validation accuracy: 0.8487


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3413107775449753, Validation loss: 0.4518603484332562, Validation accuracy: 0.8541
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.3438933089375496, Validation loss: 0.455361440628767, Validation accuracy: 0.8515


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3393090335130692, Validation loss: 0.45252792701125144, Validation accuracy: 0.8529
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.3555487354397774, Validation loss: 0.4626592263579369, Validation accuracy: 0.8495
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.30800045388937, Validation loss: 0.4203898465633392, Validation accuracy: 0.8595


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3065453920662403, Validation loss: 0.41534495905041696, Validation accuracy: 0.8621
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.30468444794416427, Validation loss: 0.4162593047320843, Validation accuracy: 0.8626


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.30124459351599214, Validation loss: 0.41241106256842613, Validation accuracy: 0.8646
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.308681168794632, Validation loss: 0.4071225342154503, Validation accuracy: 0.8652


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.29670395922660825, Validation loss: 0.4076958039402962, Validation accuracy: 0.8653
Choosen edge rank_1
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.33548324038088323, Validation loss: 0.5062292969226837, Validation accuracy: 0.8457
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.3106075311601162, Validation loss: 0.44409960225224493, Validation accuracy: 0.8601


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.30503088623285296, Validation loss: 0.4434545686841011, Validation accuracy: 0.8611
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.30795639362931254, Validation loss: 0.4353890089690685, Validation accuracy: 0.8629


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.30011931197345254, Validation loss: 0.4312777100503445, Validation accuracy: 0.8638
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.3151921480000019, Validation loss: 0.44265069380402566, Validation accuracy: 0.8609


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.3024016933143139, Validation loss: 0.44268964529037474, Validation accuracy: 0.8621
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.31791389006376264, Validation loss: 0.5023177985846996, Validation accuracy: 0.8513
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.295259146630764, Validation loss: 0.45766021937131884, Validation accuracy: 0.8644


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2838099153190851, Validation loss: 0.454972727149725, Validation accuracy: 0.8644
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.2924093334376812, Validation loss: 0.46060455322265625, Validation accuracy: 0.8632


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2887138605564833, Validation loss: 0.45383558750152586, Validation accuracy: 0.8631
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.2890679759681225, Validation loss: 0.45404143899679184, Validation accuracy: 0.8674


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.28249473416805265, Validation loss: 0.4483290806412697, Validation accuracy: 0.8675
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.30088129471242425, Validation loss: 0.5126364624500275, Validation accuracy: 0.8455
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.2765637077987194, Validation loss: 0.42272599548101425, Validation accuracy: 0.8688


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2666162504255772, Validation loss: 0.41772652238607405, Validation accuracy: 0.8695
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.26493030083179475, Validation loss: 0.4155800598859787, Validation accuracy: 0.8722


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2580169986188412, Validation loss: 0.41821491211652756, Validation accuracy: 0.8695
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.27275490847229955, Validation loss: 0.42472244128584863, Validation accuracy: 0.8669


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.26284797269105914, Validation loss: 0.42270781368017196, Validation accuracy: 0.8693
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.28262422144412996, Validation loss: 0.4994732840359211, Validation accuracy: 0.8474
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.25395096351206303, Validation loss: 0.40948469400405885, Validation accuracy: 0.8723


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2431589561253786, Validation loss: 0.40099884003400804, Validation accuracy: 0.8745
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.25899372538924215, Validation loss: 0.4026421268284321, Validation accuracy: 0.8731


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.24727987399697304, Validation loss: 0.4096686518192291, Validation accuracy: 0.8711
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.25685044038295746, Validation loss: 0.4093365804851055, Validation accuracy: 0.8687


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2504638122916222, Validation loss: 0.40855365693569184, Validation accuracy: 0.8702
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.2694826360195875, Validation loss: 0.5398930075764656, Validation accuracy: 0.8436
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.26578021804988383, Validation loss: 0.4388047182559967, Validation accuracy: 0.865


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2496968587040901, Validation loss: 0.43292268574237824, Validation accuracy: 0.8665
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.26081255662441255, Validation loss: 0.43095501124858854, Validation accuracy: 0.8684


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.24794310113787651, Validation loss: 0.4295075176656246, Validation accuracy: 0.8684
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.25235229057073594, Validation loss: 0.42810861811041834, Validation accuracy: 0.8679


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.24401306055486202, Validation loss: 0.4249296763539314, Validation accuracy: 0.869
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.2581662288606167, Validation loss: 0.567034418284893, Validation accuracy: 0.8476
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.23806688760221004, Validation loss: 0.4222394202649593, Validation accuracy: 0.8778


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.22442871885001658, Validation loss: 0.41489825963974, Validation accuracy: 0.8802
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.24562742120027542, Validation loss: 0.43564962342381475, Validation accuracy: 0.8754


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.23117063973844051, Validation loss: 0.43302985429763796, Validation accuracy: 0.8766
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.23706924951076508, Validation loss: 0.42730934008955956, Validation accuracy: 0.8802


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.23150076450407506, Validation loss: 0.4323053242266178, Validation accuracy: 0.8791
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.24730676819384098, Validation loss: 0.4701781417429447, Validation accuracy: 0.8673
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.22165905672311784, Validation loss: 0.4105902126431465, Validation accuracy: 0.8836


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.2110967582166195, Validation loss: 0.4089216794073582, Validation accuracy: 0.8827
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.22185096507519483, Validation loss: 0.40351227402687073, Validation accuracy: 0.8817


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.21158842489123345, Validation loss: 0.40341883540153506, Validation accuracy: 0.8833
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.22263864095509053, Validation loss: 0.4105032442510128, Validation accuracy: 0.8842


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.21469976833462714, Validation loss: 0.40292874291539194, Validation accuracy: 0.8846
Choosen edge rank_2
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.23725106381624936, Validation loss: 0.5072293072938919, Validation accuracy: 0.857
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.20528072893619537, Validation loss: 0.42509105920791623, Validation accuracy: 0.8751


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1962834372073412, Validation loss: 0.4197854407131672, Validation accuracy: 0.8785
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.19740033693611622, Validation loss: 0.42314294442534445, Validation accuracy: 0.879


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1958007752969861, Validation loss: 0.4212666961550713, Validation accuracy: 0.8805
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.20386504153907298, Validation loss: 0.4290062390267849, Validation accuracy: 0.8763


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1980789491608739, Validation loss: 0.4207736806571484, Validation accuracy: 0.8783
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.22568464323878287, Validation loss: 0.4645448182523251, Validation accuracy: 0.8735
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1970712310373783, Validation loss: 0.43299638718366623, Validation accuracy: 0.8816


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1900639801621437, Validation loss: 0.43058009192347524, Validation accuracy: 0.8821
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.19813954113423823, Validation loss: 0.4314357681572437, Validation accuracy: 0.8827


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.19291078585386276, Validation loss: 0.42997512966394424, Validation accuracy: 0.882
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.19912588364630937, Validation loss: 0.4329669041931629, Validation accuracy: 0.8843


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.18969801554083823, Validation loss: 0.4328139811754227, Validation accuracy: 0.8821
Choosen edge rank_1
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.21869876463711263, Validation loss: 0.4692215085029602, Validation accuracy: 0.8662
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.19589638872444629, Validation loss: 0.40027024611830714, Validation accuracy: 0.8848


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.18393063245713712, Validation loss: 0.39546777367591857, Validation accuracy: 0.8868
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.19805801059305667, Validation loss: 0.4087880206108093, Validation accuracy: 0.8829


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.19075047279894353, Validation loss: 0.404836610853672, Validation accuracy: 0.8838
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.20330681225657463, Validation loss: 0.4109653590619564, Validation accuracy: 0.8795


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.18980119992792607, Validation loss: 0.4008210791647434, Validation accuracy: 0.8838
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.20201904665678738, Validation loss: 0.5061863034963607, Validation accuracy: 0.8621
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.17864315913617612, Validation loss: 0.4240577232837677, Validation accuracy: 0.8805


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1708584567606449, Validation loss: 0.4237472194433212, Validation accuracy: 0.8822
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1818846944272518, Validation loss: 0.426510875672102, Validation accuracy: 0.8825


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.17442206187546253, Validation loss: 0.4288614185154438, Validation accuracy: 0.8793
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.18329650741070508, Validation loss: 0.43070982858538626, Validation accuracy: 0.8814


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.17748547133803366, Validation loss: 0.42565425127744677, Validation accuracy: 0.8822
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.2011210091561079, Validation loss: 0.49919515490531924, Validation accuracy: 0.872
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.18648391828685998, Validation loss: 0.45230703346431256, Validation accuracy: 0.8809


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1831309179291129, Validation loss: 0.4420829301327467, Validation accuracy: 0.8826
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1810247775837779, Validation loss: 0.44478196889162064, Validation accuracy: 0.8802


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.18067995150387287, Validation loss: 0.44100848972797396, Validation accuracy: 0.8831
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1825134417489171, Validation loss: 0.44314616821706293, Validation accuracy: 0.8816


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.173338319465518, Validation loss: 0.44550136029720305, Validation accuracy: 0.8827
Choosen edge rank_1
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.19305591478198766, Validation loss: 0.46197239339351653, Validation accuracy: 0.8821
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.16385739928483964, Validation loss: 0.44307327538728714, Validation accuracy: 0.8838


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1547848326936364, Validation loss: 0.4361730945110321, Validation accuracy: 0.8862
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1577774671241641, Validation loss: 0.4403688836097717, Validation accuracy: 0.8857


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.154317772179842, Validation loss: 0.43785881280899047, Validation accuracy: 0.8872
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.15916369576007128, Validation loss: 0.44639848709106444, Validation accuracy: 0.8844


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.15699486392736434, Validation loss: 0.44293599396944044, Validation accuracy: 0.885
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.18043838699162007, Validation loss: 0.47951348558068274, Validation accuracy: 0.8764
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.17654663522541522, Validation loss: 0.429829980507493, Validation accuracy: 0.8851


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1672952186241746, Validation loss: 0.4287807346135378, Validation accuracy: 0.8858
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.17570356340706347, Validation loss: 0.43502006962895395, Validation accuracy: 0.8841


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.17157034186273815, Validation loss: 0.43391028814017774, Validation accuracy: 0.8838
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1765448020249605, Validation loss: 0.43722173795104025, Validation accuracy: 0.8829


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1652384246662259, Validation loss: 0.4343578389286995, Validation accuracy: 0.8845
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.18128150887042285, Validation loss: 0.4516262601315975, Validation accuracy: 0.8792
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.158264209613204, Validation loss: 0.4341392070055008, Validation accuracy: 0.8873


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1518099803328514, Validation loss: 0.43148273795843123, Validation accuracy: 0.886
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1574802475348115, Validation loss: 0.43205857276916504, Validation accuracy: 0.8856


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.15225566447526218, Validation loss: 0.4296747103333473, Validation accuracy: 0.8852
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.15678806344419718, Validation loss: 0.4332724459469318, Validation accuracy: 0.8854


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.15332570543885232, Validation loss: 0.42678984358906746, Validation accuracy: 0.8873
Choosen edge rank_1
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.16873611684143544, Validation loss: 0.5388621020317078, Validation accuracy: 0.8662
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.15472217074036598, Validation loss: 0.46077552169561387, Validation accuracy: 0.8786


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.14696974440664054, Validation loss: 0.45770198792219163, Validation accuracy: 0.8801
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.15774144750088454, Validation loss: 0.45929177701473234, Validation accuracy: 0.8824


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.14682738230377435, Validation loss: 0.4547404579818249, Validation accuracy: 0.8827
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1493074809461832, Validation loss: 0.4483408710360527, Validation accuracy: 0.8818


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.14325025468692185, Validation loss: 0.45092802971601487, Validation accuracy: 0.8821
Choosen edge rank_0
Train full model


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.16433309227228166, Validation loss: 0.4742787730693817, Validation accuracy: 0.8773
Train edge rank_2


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.13388096334040164, Validation loss: 0.44090461447834967, Validation accuracy: 0.8871


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.1308405160009861, Validation loss: 0.4240249159932137, Validation accuracy: 0.8892
Train edge rank_1


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.13382656675577165, Validation loss: 0.43445589184761046, Validation accuracy: 0.8871


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.13076744908466936, Validation loss: 0.43075027465820315, Validation accuracy: 0.8878
Train edge rank_0


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Train loss: 0.1360051382854581, Validation loss: 0.4308974747359753, Validation accuracy: 0.8884


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, Train loss: 0.12968290810286998, Validation loss: 0.4222817727923393, Validation accuracy: 0.8907
Choosen edge rank_2


In [24]:
torch.save(model, "./models/greedy-tn-intermediate.pt")

In [12]:
model = torch.load("./models/greedy-tn-intermediate.pt")

In [13]:
for core in model[-1][0].construct_network():
    print(core.shape)

(128, 25)
(8, 5)
(8, 3)
(25, 5, 3, 10)


In [27]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1, verbose=True)

for param in model.parameters():
    param.requires_grad = True

accuracies = train(model, train_dataloader, val_dataloader, criterion, optimizer, device, 70, scheduler, plot=False)

Adjusting learning rate of group 0 to 1.0000e-03.


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 0, Train loss: 0.10069047795608639, Validation loss: 0.42230992525815964, Validation accuracy: 0.9006


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 1, Train loss: 0.08073057451471687, Validation loss: 0.46170606195926667, Validation accuracy: 0.9023


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 2, Train loss: 0.07556636963412165, Validation loss: 0.46502039112150667, Validation accuracy: 0.9023


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 3, Train loss: 0.07074467816762627, Validation loss: 0.4775850810110569, Validation accuracy: 0.9027


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 4, Train loss: 0.06409267967194318, Validation loss: 0.4906503976136446, Validation accuracy: 0.9021


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 5, Train loss: 0.06360460294038058, Validation loss: 0.496971602961421, Validation accuracy: 0.9026


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 6, Train loss: 0.059969392183236776, Validation loss: 0.5210489025712013, Validation accuracy: 0.9012


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 7, Train loss: 0.057404009765945375, Validation loss: 0.5233419132232666, Validation accuracy: 0.904


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 8, Train loss: 0.05471053035929799, Validation loss: 0.5300352012366056, Validation accuracy: 0.9049


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 9, Train loss: 0.051796141415834426, Validation loss: 0.5445689228177071, Validation accuracy: 0.9008


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 10, Train loss: 0.05276472736056894, Validation loss: 0.5433148363232613, Validation accuracy: 0.9049


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 11, Train loss: 0.05095300630480051, Validation loss: 0.5551462109386921, Validation accuracy: 0.9043


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 12, Train loss: 0.04939734011422843, Validation loss: 0.5568199402838946, Validation accuracy: 0.9028


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 13, Train loss: 0.04842793652985711, Validation loss: 0.5621585554629565, Validation accuracy: 0.9031


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 14, Train loss: 0.045191957585979256, Validation loss: 0.5791822398453951, Validation accuracy: 0.902


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 15, Train loss: 0.04497393043315969, Validation loss: 0.5998119181394577, Validation accuracy: 0.9032


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 16, Train loss: 0.045458667180035266, Validation loss: 0.5909288343787193, Validation accuracy: 0.9026


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 17, Train loss: 0.04364395570661873, Validation loss: 0.6045667254179716, Validation accuracy: 0.9019


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 18, Train loss: 0.039280964521691206, Validation loss: 0.5982035246491432, Validation accuracy: 0.904


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 19, Train loss: 0.04000708712823689, Validation loss: 0.6245280812680721, Validation accuracy: 0.9024


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 20, Train loss: 0.03893752221995965, Validation loss: 0.6071665266156197, Validation accuracy: 0.9065


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 21, Train loss: 0.038295597424497826, Validation loss: 0.6010682212561369, Validation accuracy: 0.905


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 22, Train loss: 0.03805767530342564, Validation loss: 0.6188604008406401, Validation accuracy: 0.9038


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 23, Train loss: 0.03759860103297979, Validation loss: 0.6388199777156115, Validation accuracy: 0.9058


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 24, Train loss: 0.03424513716436923, Validation loss: 0.657914225384593, Validation accuracy: 0.9043


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 25, Train loss: 0.0359330569668673, Validation loss: 0.6568238861113787, Validation accuracy: 0.9046


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 26, Train loss: 0.03496207093005069, Validation loss: 0.6498412424325943, Validation accuracy: 0.905


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 27, Train loss: 0.034422056196955965, Validation loss: 0.6584928761422634, Validation accuracy: 0.9041


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.
Epoch: 28, Train loss: 0.03395587487635203, Validation loss: 0.69240806594491, Validation accuracy: 0.9029


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 29, Train loss: 0.030883089649956674, Validation loss: 0.6681791953742504, Validation accuracy: 0.9066


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 30, Train loss: 0.029920726340380496, Validation loss: 0.6643769615888595, Validation accuracy: 0.9066


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 31, Train loss: 0.027805998762138188, Validation loss: 0.6735403176397086, Validation accuracy: 0.907


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 32, Train loss: 0.027176288814283907, Validation loss: 0.670934379696846, Validation accuracy: 0.907


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 33, Train loss: 0.026234560909681022, Validation loss: 0.677943958863616, Validation accuracy: 0.9059


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 34, Train loss: 0.026559334468794986, Validation loss: 0.6703467109799385, Validation accuracy: 0.9066


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 35, Train loss: 0.027089077822165564, Validation loss: 0.6815269930660725, Validation accuracy: 0.9067


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 36, Train loss: 0.02601719818986021, Validation loss: 0.6760123302042484, Validation accuracy: 0.9067


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 37, Train loss: 0.02696785794396419, Validation loss: 0.676430288925767, Validation accuracy: 0.9066


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 38, Train loss: 0.024375306863803417, Validation loss: 0.6768294544517994, Validation accuracy: 0.9068


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 39, Train loss: 0.02526845447998494, Validation loss: 0.680079535022378, Validation accuracy: 0.9064


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 40, Train loss: 0.024356336403870957, Validation loss: 0.6809535840898753, Validation accuracy: 0.9071


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 41, Train loss: 0.022844774030381813, Validation loss: 0.6897170884907246, Validation accuracy: 0.9055


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 42, Train loss: 0.024138390675070696, Validation loss: 0.6863423530012369, Validation accuracy: 0.9063


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 43, Train loss: 0.023863326172111555, Validation loss: 0.6906226281821728, Validation accuracy: 0.9055


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 44, Train loss: 0.02322769441397395, Validation loss: 0.6885126460343599, Validation accuracy: 0.907


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 45, Train loss: 0.023326197757269254, Validation loss: 0.6875695312023162, Validation accuracy: 0.9067


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 46, Train loss: 0.022670422890107147, Validation loss: 0.6924616804718972, Validation accuracy: 0.9084


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 47, Train loss: 0.02368619643279817, Validation loss: 0.6962670183181763, Validation accuracy: 0.9068


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 48, Train loss: 0.02307517756568268, Validation loss: 0.7042686600238085, Validation accuracy: 0.907


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 49, Train loss: 0.022260090406052768, Validation loss: 0.7021327286213637, Validation accuracy: 0.907


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 50, Train loss: 0.023087270073941908, Validation loss: 0.7035068422555923, Validation accuracy: 0.9061


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 51, Train loss: 0.022941621282836423, Validation loss: 0.7085012891143561, Validation accuracy: 0.9063


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 52, Train loss: 0.02336596248066053, Validation loss: 0.7051025847345591, Validation accuracy: 0.9072


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 53, Train loss: 0.02424861330818385, Validation loss: 0.6985810046643018, Validation accuracy: 0.9064


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 54, Train loss: 0.022042129133827984, Validation loss: 0.7086646569520235, Validation accuracy: 0.9063


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 55, Train loss: 0.021825314962188713, Validation loss: 0.7147218849509954, Validation accuracy: 0.9069


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 56, Train loss: 0.02268018133181613, Validation loss: 0.71470344170928, Validation accuracy: 0.906


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 57, Train loss: 0.02228657011559699, Validation loss: 0.7088157998025417, Validation accuracy: 0.9079


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 58, Train loss: 0.02112864910485223, Validation loss: 0.7129875706136226, Validation accuracy: 0.9067


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 59, Train loss: 0.021186710735666567, Validation loss: 0.7113412816822529, Validation accuracy: 0.9075


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 60, Train loss: 0.021520086824020835, Validation loss: 0.7151052296161652, Validation accuracy: 0.9073


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 61, Train loss: 0.021303994254092688, Validation loss: 0.7130614249408245, Validation accuracy: 0.9077


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 62, Train loss: 0.023317679733154364, Validation loss: 0.7151071012020112, Validation accuracy: 0.9072


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 63, Train loss: 0.022897770242765545, Validation loss: 0.7167707352340221, Validation accuracy: 0.9066


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 64, Train loss: 0.022726821080665104, Validation loss: 0.7112627300620079, Validation accuracy: 0.9071


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 65, Train loss: 0.020932092103699687, Validation loss: 0.7169556189328432, Validation accuracy: 0.9074


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 66, Train loss: 0.02161267951654736, Validation loss: 0.718777424544096, Validation accuracy: 0.9073


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 67, Train loss: 0.021761620508390478, Validation loss: 0.7134135715663433, Validation accuracy: 0.9069


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 68, Train loss: 0.022366122332401574, Validation loss: 0.712723551467061, Validation accuracy: 0.9066


  0%|          | 0/500 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-05.
Epoch: 69, Train loss: 0.02334692127339076, Validation loss: 0.7121564545482397, Validation accuracy: 0.9071


In [None]:
all_losses, predicted_labels, true_labels = predict(model, val_dataloader, criterion, device)
assert len(predicted_labels) == len(val_dataset)
accuracy = accuracy_score(predicted_labels.to("cpu"), true_labels.to("cpu"))
print("tests passed")

In [None]:
torch.save({
    "model": model.state_dict(),
}, "./models/conv-trl-2.1-state-dict.pt")