In [None]:
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import pandas as pd
import os
import time
import itertools
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm
from torch.utils.data import Dataset
from scipy.spatial.distance import cdist
from torchvision import transforms
from sklearn.model_selection import train_test_split

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [None]:
size = 2          # images are resized to size x size; the pixel graph has size**2 nodes
batch_size = 1    # DataLoader batch size
epochs = 100      # number of training epochs

# Seed torch's global RNG so random splitting, weight init, dropout masks and
# DataLoader shuffling are reproducible across "Restart & Run All".
seed = 42
torch.manual_seed(seed)

In [None]:
# Disabled (kept as an inert string literal, so nothing here is defined or executed):
# a Dataset that reads grayscale images listed in a CSV (column 0 = file name relative
# to img_dir, column 1 = label), resizes them to size x size with OpenCV, and applies
# the optional image / label transforms.
"""
class CustomImageDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        read_img = cv.imread(img_path,cv.IMREAD_GRAYSCALE)
        image = cv.resize(read_img,(size,size))
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
"""

In [None]:
# Disabled (kept as an inert string literal): builds train/test datasets from MRI brain
# image folders and CSV files on Google Drive using the CustomImageDataset sketch above.
# Re-enabling it would also need the Drive mount cell and a real CustomImageDataset class.
"""
path_data_train = "/content/drive/MyDrive/Dataset_MRI_Brain/Data_training"
path_csv_train = "/content/drive/MyDrive/Dataset_MRI_Brain/Training.csv"
path_data_test = "/content/drive/MyDrive/Dataset_MRI_Brain/Data_testing"
path_csv_test = "/content/drive/MyDrive/Dataset_MRI_Brain/Testing.csv"

train_dataset = CustomImageDataset(csv_file = path_csv_train , img_dir = path_data_train, transform = transforms.ToTensor(),
                               target_transform = transforms.Compose([
                                  lambda x:torch.LongTensor([x]),
                                  #lambda x:F.one_hot(x, 4)
                                 ]))

test_dataset = CustomImageDataset(csv_file = path_csv_test , img_dir = path_data_test, transform = transforms.ToTensor(),
                               target_transform = transforms.Compose([
                                  lambda x:torch.LongTensor([x]),
                                  #lambda x:F.one_hot(x, 4)
                                 ]))
"""

In [None]:
def _load_mnist_split(train):
    """Return the MNIST training split (train=True) or test split (train=False) stored under '/files/'.

    Images go through ToTensor and are then resized to (size, size); every label is
    wrapped in a 1-element LongTensor.
    """
    image_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Resize((size, size))])
    label_pipeline = transforms.Compose([lambda x: torch.LongTensor([x])])
    return torchvision.datasets.MNIST(
        '/files/',
        train=train,
        download=True,
        transform=image_pipeline,
        target_transform=label_pipeline,
    )


train_dataset = _load_mnist_split(train=True)
test_dataset = _load_mnist_split(train=False)

In [None]:
#train_indices, val_indices, _, _ = train_test_split(
#    range(len(train_dataset)),
#    train_dataset.targets,
#    stratify=train_dataset.targets,
#    test_size=0.1,
#)
#train_split = torch.utils.data.Subset(train_dataset, train_indices)
#val_split = torch.utils.data.Subset(train_dataset, val_indices)

# Hold out a quarter of the training set for validation (45,000 / 15,000 for MNIST's
# 60,000 training images). A dedicated seeded generator makes the split reproducible.
val_fraction = 0.25
n_val = int(len(train_dataset) * val_fraction)
n_train = len(train_dataset) - n_val
train_split, val_split = torch.utils.data.random_split(
    train_dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
)

# Shuffle training data every epoch; keep validation/test order fixed so that
# per-epoch losses are computed over the same sequence each time.
train_loader = torch.utils.data.DataLoader(train_split, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_split, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(len(train_loader), len(val_loader), len(test_loader))

In [None]:
# One (column, row) coordinate per pixel of the size x size grid, in row-major order.
col, row = np.meshgrid(np.arange(size), np.arange(size))
coord = np.column_stack((col.ravel(), row.ravel()))

# Pairwise Euclidean distances between pixels (float32, on `device`); their variance
# is used as the Gaussian kernel bandwidth when building the graph.
distance = torch.tensor(cdist(coord, coord), dtype=torch.float32, device=device)
sigma_distance = distance.var()

In [None]:
def create_graph_shift_operator(img, threshold=0.01, distance_matrix=None, bandwidth=None):
    """Build the symmetrically normalised pixel-grid adjacency D^{-1/2} A D^{-1/2}.

    Edge weights are a Gaussian kernel on pixel distance,
    ``A_ij = exp(-d_ij**2 / bandwidth)``; weights strictly below ``threshold`` are zeroed.

    Args:
        img: unused, kept for call compatibility. Intensity-based edge weights are
            currently disabled, so the operator is identical for every image.
        threshold: edge weights strictly below this value are dropped.
        distance_matrix: (N, N) pairwise pixel distances; defaults to the global ``distance``.
        bandwidth: kernel scale; defaults to the global ``sigma_distance``.

    Returns:
        (N, N) tensor on the same device as ``distance_matrix``.
    """
    if distance_matrix is None:
        distance_matrix = distance
    if bandwidth is None:
        bandwidth = sigma_distance

    adjacency = torch.exp(-distance_matrix ** 2 / bandwidth)
    adjacency = adjacency.masked_fill(adjacency < threshold, 0.0)

    # The degree matrix is diagonal, so D^{-1/2} is the elementwise reciprocal square
    # root of the degrees; no general matrix inverse is needed. Each node keeps its
    # self-loop of weight exp(0) = 1 (for threshold <= 1), so degrees are >= 1.
    inv_sqrt_degree = adjacency.sum(dim=1).rsqrt()
    return inv_sqrt_degree[:, None] * adjacency * inv_sqrt_degree[None, :]

In [None]:
class GraphConv(nn.Module):
    """Graph convolution: project node features, then propagate them with a graph shift operator.

    Computes ``S @ (X W)``, where ``W`` is a bias-free linear projection.
    """

    def __init__(self, in_features, out_features):
        super(GraphConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.projection = nn.Linear(in_features, out_features, bias=False)

    def forward(self, input, graph_shift_operator):
        """Apply the layer.

        Args:
            input: node features of shape (B, N, in_features).
            graph_shift_operator: (B, N, N) per-sample operators, or a single (N, N)
                operator shared by the whole batch.

        Returns:
            Tensor of shape (B, N, out_features).
        """
        support = self.projection(input)
        # One batched matmul replaces the per-sample Python loop; a 2-D operator
        # broadcasts across the batch dimension.
        return torch.matmul(graph_shift_operator, support)

In [None]:
class GCN(nn.Module):
    """Nine stacked GraphConv layers followed by a linear read-out over all nodes.

    Args:
        img_size: side length of the square input image; the graph has img_size**2 nodes.
        nfeat: number of input features per node.
        nclass: number of output classes.
        dropout: dropout probability applied after the 2nd, 4th and 6th graph convolutions.
    """

    def __init__(self, img_size, nfeat, nclass, dropout):
        super().__init__()
        self.N = img_size ** 2
        # Registered in order as gc1..gc9: nfeat -> 4 -> 128 -> 512 -> 64 -> 24 -> 12 -> 8 -> 4 -> 1.
        widths = [nfeat, 4, 128, 512, 64, 24, 12, 8, 4, 1]
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"gc{idx}", GraphConv(fan_in, fan_out))
        self.dropout = dropout
        # Maps the single remaining feature of every node to class scores.
        self.fc = nn.Linear(self.N, nclass, bias=False)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = [param for param in self.parameters() if param.dim() > 1]
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, x, graph_shift_operator):
        """Return class scores of shape (B, 1, nclass) for node features x of shape (B, N, nfeat)."""
        n_batch = x.size(0)
        hidden_layers = [self.gc1, self.gc2, self.gc3, self.gc4,
                         self.gc5, self.gc6, self.gc7, self.gc8]
        for depth, layer in enumerate(hidden_layers, start=1):
            x = F.relu(layer(x, graph_shift_operator))
            if depth in (2, 4, 6):
                x = F.dropout(x, self.dropout, training=self.training)
        # Final convolution has no activation: it collapses each node to one value.
        x = self.gc9(x, graph_shift_operator)
        return self.fc(x.reshape(n_batch, 1, -1))

    def loss(self, pred, label):
        """Mean cross-entropy between class scores `pred` and integer targets `label`."""
        return F.cross_entropy(pred, label)

In [None]:
# 10 MNIST classes; each pixel node carries a single grayscale feature.
model = GCN(img_size=size, nfeat=1, nclass=10, dropout=0.5).to(device)
model

GCN(
  (gc1): GraphConv(
    (projection): Linear(in_features=1, out_features=4, bias=False)
  )
  (gc2): GraphConv(
    (projection): Linear(in_features=4, out_features=128, bias=False)
  )
  (gc3): GraphConv(
    (projection): Linear(in_features=128, out_features=512, bias=False)
  )
  (gc4): GraphConv(
    (projection): Linear(in_features=512, out_features=64, bias=False)
  )
  (gc5): GraphConv(
    (projection): Linear(in_features=64, out_features=24, bias=False)
  )
  (gc6): GraphConv(
    (projection): Linear(in_features=24, out_features=12, bias=False)
  )
  (gc7): GraphConv(
    (projection): Linear(in_features=12, out_features=8, bias=False)
  )
  (gc8): GraphConv(
    (projection): Linear(in_features=8, out_features=4, bias=False)
  )
  (gc9): GraphConv(
    (projection): Linear(in_features=4, out_features=1, bias=False)
  )
  (fc): Linear(in_features=4, out_features=10, bias=False)
)

In [None]:
# Adam over every model parameter, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# `create_graph_shift_operator` ignores its `img` argument, so the operator is the same
# for every image: build it once rather than once per sample per batch. If
# intensity-dependent edge weights are re-enabled, move this back into the loops.
base_gso = create_graph_shift_operator(img=None)


def _prepare_batch(data, label):
    """Move one batch to `device` and shape it for the GCN.

    Returns (features, targets, gso): features of shape (B, N, 1) with one node per
    pixel value, labels flattened to shape (B,), and the shared operator broadcast
    to (B, N, N) as a view (no copy).
    """
    n_batch = data.size(0)  # local name: must not clobber the global `batch_size`
    features = data.to(device).reshape(n_batch, -1, 1)
    targets = label.to(device).reshape(-1)
    gso = base_gso.expand(n_batch, -1, -1)
    return features, targets, gso


train_losses, val_losses = [], []
for epoch in range(epochs):
    start = time.time()

    model.train()
    train_loss = 0.0
    for data, label in tqdm(train_loader, desc=f"train {epoch + 1}/{epochs}"):
        features, targets, gso = _prepare_batch(data, label)
        optimizer.zero_grad()
        pred = model(features, gso)
        loss = model.loss(pred.reshape(-1, pred.shape[-1]), targets)
        loss.backward()
        optimizer.step()
        # Running mean of per-batch losses over the epoch.
        train_loss += loss.item() / len(train_loader)
    train_losses.append(train_loss)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, label in tqdm(val_loader, desc=f"val {epoch + 1}/{epochs}"):
            features, targets, gso = _prepare_batch(data, label)
            pred = model(features, gso)
            loss = model.loss(pred.reshape(-1, pred.shape[-1]), targets)
            val_loss += loss.item() / len(val_loader)
    val_losses.append(val_loss)

    end = time.time()
    print('Epoch: [{}/{}] - Train loss: {:.6f} - Val loss: {:.6f} - Time: {:.6f}'.format(epoch + 1, epochs, train_loss, val_loss, end - start))

  0%|          | 0/45000 [00:00<?, ?it/s]

torch.Size([1, 1, 2, 2])


  0%|          | 20/45000 [00:03<1:41:26,  7.39it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  0%|          | 60/45000 [00:03<25:29, 29.38it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  0%|          | 96/45000 [00:04<13:11, 56.73it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  0%|          | 133/45000 [00:04<09:00, 83.03it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  0%|          | 149/45000 [00:04<08:35, 87.05it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  0%|          | 176/45000 [00:04<07:51, 95.14it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])

  0%|          | 203/45000 [00:05<06:58, 107.10it/s]


torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  0%|          | 216/45000 [00:05<06:49, 109.36it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 241/45000 [00:05<07:10, 103.96it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 266/45000 [00:05<06:50, 108.85it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 289/45000 [00:05<07:15, 102.61it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 312/45000 [00:06<06:59, 106.54it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 323/45000 [00:06<07:46, 95.70it/s] 

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 344/45000 [00:06<07:55, 93.88it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 366/45000 [00:06<07:32, 98.66it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 389/45000 [00:06<07:02, 105.56it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 400/45000 [00:07<08:52, 83.73it/s] 

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 421/45000 [00:07<08:27, 87.89it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 441/45000 [00:07<08:34, 86.60it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 462/45000 [00:07<08:31, 87.01it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])

  1%|          | 472/45000 [00:07<08:55, 83.21it/s]


torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 493/45000 [00:08<08:09, 90.87it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 516/45000 [00:08<07:42, 96.12it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 536/45000 [00:08<08:45, 84.67it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|          | 545/45000 [00:08<08:47, 84.30it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|▏         | 563/45000 [00:08<09:04, 81.57it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|▏         | 585/45000 [00:09<08:12, 90.15it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|▏         | 606/45000 [00:09<07:48, 94.68it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|▏         | 626/45000 [00:09<08:46, 84.25it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|▏         | 644/45000 [00:09<09:37, 76.80it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  1%|▏         | 664/45000 [00:10<08:30, 86.84it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 687/45000 [00:10<07:27, 98.97it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 698/45000 [00:10<07:15, 101.68it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 719/45000 [00:10<08:14, 89.59it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 742/45000 [00:10<07:32, 97.84it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 762/45000 [00:11<07:35, 97.21it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 785/45000 [00:11<07:00, 105.23it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])

  2%|▏         | 807/45000 [00:11<07:12, 102.23it/s]


torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 830/45000 [00:11<07:07, 103.31it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 852/45000 [00:12<07:12, 102.03it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 875/45000 [00:12<06:52, 106.89it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 897/45000 [00:12<07:12, 101.92it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 908/45000 [00:12<07:13, 101.61it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 932/45000 [00:12<07:03, 103.99it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 954/45000 [00:13<07:20, 100.05it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 977/45000 [00:13<07:09, 102.39it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 990/45000 [00:13<06:47, 108.11it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 1012/45000 [00:13<08:05, 90.61it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])

  2%|▏         | 1037/45000 [00:13<06:57, 105.28it/s]


torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 1063/45000 [00:14<06:32, 112.00it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 1089/45000 [00:14<06:48, 107.44it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 1101/45000 [00:14<07:25, 98.51it/s] 

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  2%|▏         | 1122/45000 [00:14<08:12, 89.07it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1132/45000 [00:14<08:58, 81.51it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1149/45000 [00:15<09:19, 78.40it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1165/45000 [00:15<09:46, 74.69it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1181/45000 [00:15<09:39, 75.60it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])

  3%|▎         | 1189/45000 [00:15<10:22, 70.34it/s]


torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])

  3%|▎         | 1204/45000 [00:15<11:49, 61.72it/s]


torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1219/45000 [00:16<11:36, 62.87it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1226/45000 [00:16<12:08, 60.10it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1239/45000 [00:16<13:43, 53.12it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])


  3%|▎         | 1255/45000 [00:16<09:48, 74.38it/s]

torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])
torch.Size([1, 1, 2, 2])





KeyboardInterrupt: ignored

In [None]:
# Learning curves: training vs. validation loss recorded by the training cell.
# NOTE(review): assumes train_losses / val_losses each hold one value per epoch — TODO confirm
# against the training loop before trusting the x-axis label.
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train loss')
ax.plot(val_losses, label='Val loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and validation loss')
ax.legend()
plt.show()

In [None]:
# Accumulate a confusion matrix over the test set: rows index the true label,
# columns index the predicted label (argmax over the model's last output dim).
NUM_CLASSES = 10  # assumes 10 target classes — TODO confirm against the dataset

confusion_matrix = np.zeros((NUM_CLASSES, NUM_CLASSES))

# Switch layers such as dropout / batch-norm (if the model has any) to inference behaviour.
model.eval()

with torch.no_grad():
    for data, label in tqdm(test_loader):
        # Local name on purpose: assigning to `batch_size` here would clobber the
        # global hyper-parameter defined in the configuration cell.
        n_samples = data.size(0)

        # One graph shift operator per sample, built from the raw CPU image, then moved
        # to the same device as the data so the model receives consistently placed inputs.
        graph_shift_operator = torch.stack(
            [create_graph_shift_operator(data[b].squeeze().detach().numpy()) for b in range(n_samples)]
        ).to(device)
        data, label = data.to(device), label.to(device)

        # Flatten each image into a (num_nodes, 1) node-feature column.
        data = data.reshape(n_samples, -1, 1)
        outputs = model(data, graph_shift_operator)

        preds = torch.argmax(outputs, -1)

        true_idx = label.reshape(-1).long().cpu().numpy()
        pred_idx = preds.reshape(-1).long().cpu().numpy()
        # zip() in a per-element loop would silently drop unmatched entries; fail loudly instead.
        assert true_idx.shape == pred_idx.shape, (true_idx.shape, pred_idx.shape)

        # Vectorised scatter-add; np.add.at accumulates repeated (true, pred) pairs correctly.
        np.add.at(confusion_matrix, (true_idx, pred_idx), 1)


In [None]:

# Summarise test-set performance from `confusion_matrix`
# (rows = true class, columns = predicted class) and plot the row-normalised matrix.
row_totals = confusion_matrix.sum(axis=1, keepdims=True)

# Row-normalise without dividing by zero: a class with no test samples gets NaN
# (rendered blank in the heatmap) instead of triggering a RuntimeWarning.
normalized_confusion_matrix = np.divide(
    confusion_matrix,
    row_totals,
    out=np.full_like(confusion_matrix, np.nan, dtype=float),
    where=row_totals > 0,
)

# Per-class accuracy (recall) is exactly the diagonal of the row-normalised matrix.
print('Per-class accuracy', 100 * np.diag(normalized_confusion_matrix))
print('Accuracy:', 100 * np.trace(confusion_matrix) / confusion_matrix.sum())

fig, ax = plt.subplots(figsize=(8, 6))
# `fmt` only takes effect together with `annot=True`; fix the colour scale to [0, 1]
# so heatmaps from different runs are comparable.
sns.heatmap(normalized_confusion_matrix, annot=True, fmt='.2f', vmin=0, vmax=1, ax=ax)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted Values')
ax.set_ylabel('Actual Values ')
plt.show()

In [None]:
# torch.save(model.state_dict(), '/content/drive/MyDrive/model_GCN_v2.pt')