# Import packages e define functions

In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import random
import pickle
import datetime
import os
import torch.nn.functional as F
from scipy.stats import special_ortho_group
from scipy.spatial.transform import Rotation as R
import torch.nn.functional as F
from pytorch_metric_learning import miners, losses
import wandb


In [2]:
seed=999
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
def CreateCouples(pt):
    
    """ This function, modifies the shape of the tensor to fit the model 

    """
    clusters= pt.shape[0]
    couples = []
    # for each subcluster
    for i in tqdm(range(0,clusters)):
        
        # discover the number of fragments
        n_frags = pt[i][0].shape[0]

        # exract the adj matrix
        matr= pt[i][0]

        # exract the cluster of fragments
        data = pt[i][1]
        
        for j in range(0,n_frags -1):
            
            init = j+1
            for k in range(init,n_frags): 

             couples.append([data[j], data[k], matr[j][k]])
    return couples 

def CreateCouples_cluster(pt):
    """ This function, modifies the shape of the tensor to fit the model """
    clusters = pt.shape[0]
    cluster_list = []

    for i in tqdm(range(clusters)):
        n_frags = pt[i][0].shape[0]
        matr = pt[i][0]
        data = pt[i][1]
        cluster_couples = []

        for j in range(n_frags - 1):
            init = j + 1
            for k in range(init, n_frags):
                cluster_couples.append([data[j], data[k], matr[j][k]])

        cluster_list.append(cluster_couples)

    return cluster_list
def center_in_origin(frag):
    min_vals, _ = torch.min(frag[:, 0:3], axis=0)
    max_vals, _ = torch.max(frag[:, 0:3], axis=0)
    frag[:, 0:3] = (frag[:, 0:3] - min_vals) / (max_vals - min_vals)
    
    return frag
    
def normalize(batch):
    out=[]
    for element in batch:
        out.append(center_in_origin(element))
    out_tensor = torch.stack(out)    
    return out_tensor  



def use_GPU():
    """ This function activates the gpu 
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(torch.cuda.get_device_name(0), "is available and being used")
    else:
        device = torch.device("cpu")
        print("GPU is not available, using CPU instead") 
    return device  





def translate_to_origin(frag):
    """ This function translate each fragment in the origin
    """
    frag[:,:3] -= torch.mean(frag[:,:3]) 
    return frag

def apply_translation(batch):
    """ This function apply translate_to_origin() to each fragment in the batch
    """
    out=[]
    for element in batch:
        out.append(translate_to_origin(element))
    out_tensor = torch.stack(out)    
    return out_tensor


def random_rotation(frag):

    randrot = (torch.rand(3)*360).tolist()
    r = R.from_euler('zyx', randrot, degrees=True)
    frag[:,:3] = torch.from_numpy(r.apply(frag[:,:3]))
    frag[:,3:6] = torch.from_numpy(r.apply(frag[:,3:6]))
    return frag

def apply_randomrotations(batch):
    """ This function apply random_rotation() to each fragment in the batch
    """
    out=[]
    for element in batch:
        out.append(random_rotation(element))
    out_tensor = torch.stack(out)    
    return out_tensor


class ContrastiveLoss(torch.nn.Module):
    def __init__(self, m=2.0):
        super(ContrastiveLoss, self).__init__()
        self.m = m

    def forward(self, y1, y2, d):
        euc_dist = torch.nn.functional.pairwise_distance(y1, y2)

        if d.dim() == 0:  # Se d è uno scalare
            if d == 0:
                return torch.mean(torch.pow(euc_dist, 2))  # Distanza quadratica
            else:  # d == 1
                delta = self.m - euc_dist
                delta = torch.clamp(delta, min=0.0, max=None)
                return torch.mean(torch.pow(delta, 2))
        else:  # Se d è un tensore di valori 0 e 1
            is_same = d == 0
            is_diff = d == 1

            loss_same = torch.pow(euc_dist[is_same], 2).mean() if torch.any(is_same) else torch.tensor(0.0).to(euc_dist.device)
            loss_diff = torch.pow(torch.clamp(self.m - euc_dist[is_diff], min=0.0), 2).mean() if torch.any(is_diff) else torch.tensor(0.0).to(euc_dist.device)

            return (loss_same + loss_diff) / (1.0 + torch.any(is_same).float() + torch.any(is_diff).float())

In [4]:
# https://github.com/qq456cvb/Point-Transformers

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm；
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)

def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    raw_size = idx.size()
    idx = idx.reshape(raw_size[0], -1)
    res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
    return res.reshape(*raw_size, -1)


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids

def sample_and_group(npoint, nsample, xyz, points):
    B, N, C = xyz.shape
    S = npoint 
    
    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]

    new_xyz = index_points(xyz, fps_idx) 
    new_points = index_points(points, fps_idx)

    dists = square_distance(new_xyz, xyz)  # B x npoint x N
    idx = dists.argsort()[:, :, :nsample]  # B x npoint x K

    grouped_points = index_points(points, idx)
    grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1)
    new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1)
    return new_xyz, new_points


class Local_op(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        b, n, s, d = x.size()  # torch.Size([32, 512, 32, 6]) 
        x = x.permute(0, 1, 3, 2)
        x = x.reshape(-1, d, s)
        batch_size, _, N = x.size()
        x = self.relu(self.bn1(self.conv1(x))) # B, D, N
        x = self.relu(self.bn2(self.conv2(x))) # B, D, N
        x = torch.max(x, 2)[0]
        x = x.view(batch_size, -1)
        x = x.reshape(b, n, -1).permute(0, 2, 1)
        return x


class SA_Layer(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.q_conv.weight = self.k_conv.weight 
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c 
        x_k = self.k_conv(x)# b, c, n        
        x_v = self.v_conv(x)
        energy = x_q @ x_k # b, n, n 
        attention = self.softmax(energy)
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
        x_r = x_v @ attention # b, c, n 
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
        x = x + x_r
        return x
    

class StackedAttention(nn.Module):
    def __init__(self, channels=256):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)

        self.sa1 = SA_Layer(channels)
        self.sa2 = SA_Layer(channels)
        self.sa3 = SA_Layer(channels)
        self.sa4 = SA_Layer(channels)

        self.relu = nn.ReLU()
        
    def forward(self, x):
        # 
        # b, 3, npoint, nsample  
        # conv2d 3 -> 128 channels 1, 1
        # b * npoint, c, nsample 
        # permute reshape
        batch_size, _, N = x.size()

        x = self.relu(self.bn1(self.conv1(x))) # B, D, N
        x = self.relu(self.bn2(self.conv2(x)))

        x1 = self.sa1(x)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)
        
        x = torch.cat((x1, x2, x3, x4), dim=1)

        return x

In [5]:

class Branch(nn.Module):
    def __init__(self):
        super().__init__()
        
        d_points = 7 # we have 7 features for each point
        self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
        self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
        self.pt_last = StackedAttention()

        self.relu = nn.ReLU()
        self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(negative_slope=0.2))

        
    def forward(self, x):
        xyz = x[..., :3]
        x = x.permute(0, 2, 1)
        batch_size, _, _ = x.size()
        x= x.double()
        x = self.relu(self.bn1(self.conv1(x))) # B, D, N
        x = self.relu(self.bn2(self.conv2(x))) # B, D, N
        x = x.permute(0, 2, 1)
        new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x)         
        feature_0 = self.gather_local_0(new_feature)
        feature = feature_0.permute(0, 2, 1)
        new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) 
        feature_1 = self.gather_local_1(new_feature)
        
        x = self.pt_last(feature_1)
        x = torch.cat([x, feature_1], dim=1)
        x = self.conv_fuse(x)
        x = torch.max(x, 2)[0] # Returns the maximum value of all elements in the input tensor. (2 elementes for each vector)
        x = x.view(batch_size, -1) # Returns a new tensor with the same data as the self tensor but of a different shape.
        
        return x
    
    
class PairModel1(nn.Module):
    def __init__(self):
        super().__init__()
        
        output_channels = 2 # it's a binary classification

        self.branch1 = Branch()
        self.branch2 = Branch()
        self.dp1 = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
            
        # classificator
        self.linear1 = nn.Linear(2048, 512, bias=False)
        self.bn1 = nn.BatchNorm1d(512)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dp3 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, output_channels)
        
    def forward(self, batch_1, batch_2):
        
        x_1 = self.branch1(batch_1)
        x_2 = self.branch2(batch_2)
        #print(x_1.shape)
        #print(x_2.shape)
        x_mult = x_1 * x_2 # let's sum the output of the two branches 
        x_sum = x_1 + x_2
        x = torch.cat((x_mult, x_sum), dim=1) 
        #x = self.dp1(x)

        # classificator
        x = self.relu(self.bn1(self.linear1(x)))
        x = self.dp2(x)
        x = self.relu(self.bn2(self.linear2(x)))
        x = self.dp3(x)
        x = self.linear3(x)
        
        return x


In [6]:
# data loading
train = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\train_pair_dataset_REG.pt")
val = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\val_pair_dataset_REG.pt")
test = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\test_pair_dataset_REG.pt")

In [7]:
#let's find the largest clusters

threshold = 70

# Train
indices = [i for i, data in enumerate(train) if data[0].shape[0] > threshold]
count = len(indices)

# Val
indices_val = [i for i, data in enumerate(val) if data[0].shape[0] > threshold]
count_val = len(indices_val)

# Test
indices_test = [i for i, data in enumerate(test) if data[0].shape[0] > threshold]
count_test = len(indices_test)

print("Positions to remove (Train):", indices)
print("Positions to remove (Val):", indices_val)
print("Positions to remove (Test):", indices_test)


# We are removing the largest clusters for computational reasons
mask = torch.ones(train.shape[0], dtype=torch.bool)
mask[indices] = False
filtered_tensor = train[mask]
train = filtered_tensor

mask_val = torch.ones(val.shape[0], dtype=torch.bool)
mask_val[indices_val] = False
filtered_tensor_val = val[mask_val]
val = filtered_tensor_val

mask_test = torch.ones(test.shape[0], dtype=torch.bool)
mask_test[indices_test] = False
filtered_tensor_test = test[mask_test]
test= filtered_tensor_test   

print(train.shape)
print(val.shape)
print(test.shape)

Positions to remove (Train): [2, 14, 20, 22, 28, 30, 31, 35, 36, 39, 71, 78, 87, 91, 95, 109, 136, 148, 163, 167, 192, 197, 205, 209, 215, 217, 218, 222, 228, 241, 255, 261, 263, 265, 273, 284, 300, 308, 318, 337, 344, 346, 353, 361, 364, 369, 398, 420, 423, 427, 431, 433, 468, 469, 471, 474, 489, 494, 496, 507, 513, 516, 517, 538, 547, 551, 564, 577, 591, 595, 597, 600, 607, 613, 617, 633, 635, 644, 651, 652, 660, 664, 675, 689, 694, 699, 700, 707, 711, 721, 731, 735, 754, 757, 758, 763, 768, 771, 776, 783, 789, 813, 818, 831, 834, 837, 841, 844, 846, 874, 879, 896, 909, 917, 931, 943, 971, 977, 979, 988, 991, 994, 996, 999, 1042, 1045, 1050, 1057, 1077, 1078, 1079, 1084, 1100, 1109, 1118, 1126, 1127, 1129, 1130, 1131, 1143, 1150, 1156, 1188, 1217, 1219, 1227, 1254, 1274, 1284, 1290, 1295, 1299, 1316, 1324, 1334, 1336, 1337, 1347, 1353, 1356, 1390, 1392, 1412, 1418, 1460, 1463, 1466, 1469, 1474, 1483, 1489, 1498, 1499, 1504, 1505, 1507, 1514]
Positions to remove (Val): [6, 12, 13, 22,

# First Trial

In [25]:
# Change the shape of the data
train_couples = CreateCouples(train)
val_couples = CreateCouples(val)
#test_couples = CreateCouples(test)

100%|██████████| 1348/1348 [00:00<00:00, 2227.90it/s]
100%|██████████| 292/292 [00:00<00:00, 984.74it/s]


In [24]:
train_couples_c = CreateCouples_cluster(train)
val_couples_c = CreateCouples_cluster(val)

100%|██████████| 1348/1348 [00:00<00:00, 1820.22it/s]
100%|██████████| 292/292 [00:00<00:00, 3025.54it/s]


In [33]:
train_couples_0 = [item for item in train_couples if item[2] == 0]
train_couples_1 = [item for item in train_couples if item[2] == 1]


couples_per_epoch = 10000

In [26]:
train_couples_1

[[array([[-0.82133717,  0.99539159, -0.88014335, ...,  0.89506071,
          -0.35446518,  0.00640204],
         [-0.88240756,  1.0048247 , -0.902944  , ..., -0.35452528,
           0.83211797,  0.02206983],
         [-0.88315378,  0.9996307 , -0.91662904, ..., -0.35452528,
           0.83211797,  0.02206983],
         ...,
         [-0.78929048,  0.94367599, -0.89353457, ...,  0.89506071,
          -0.35446518,  0.00640204],
         [-0.78529401,  0.92398449, -0.90642682, ...,  0.89506071,
          -0.35446518,  0.00640204],
         [-0.7526425 ,  0.8610451 , -0.94581723, ..., -0.06008591,
          -0.99808797,  0.04198902]]),
  array([[-2.01222612e-01, -3.46018287e-01, -1.00183879e+00, ...,
           8.78542604e-01,  4.25447346e-01,  2.50529974e-05],
         [-1.22905692e-01, -4.20159158e-01, -8.02625360e-01, ...,
           3.02165508e-01,  6.96052903e-02,  8.90575816e-05],
         [-1.29358891e-01, -4.20507161e-01, -8.13636383e-01, ...,
           3.02165508e-01,  6.96052903

In [10]:
len(train_couples_0)

1467881

In [13]:
val_couples_0 = [item for item in val_couples if item[2] == 0]
val_couples_1 = [item for item in val_couples if item[2] == 1]


val_couples_per_epoch = 8000

In [15]:
len(val_couples_1)

40076

In [None]:
# let's exract only some couples to make the train cycle faster
el_1 = [[item[0], item[1],item[2]] for item in train_couples if item[2] == 1]
el_0 = [[item[0], item[1],item[2]] for item in train_couples if item[2] == 0]
el_0_s = el_0[0:100000]
el_1_s = el_1[0:100000]

train_list = el_1_s + el_0_s
random.shuffle(train_list)



val_1 = [[item[0], item[1],item[2]] for item in val_couples if item[2] == 1]
val_0 = [[item[0], item[1],item[2]] for item in val_couples if item[2] == 0]
val_0_s = val_0[0:10000]
val_1_s = val_1[0:10000]

val_list = val_1_s + val_0_s
random.shuffle(val_list)

In [None]:
with open('train2000k.pickle', 'wb') as file:
    pickle.dump(train_list, file)

with open('val2000k.pickle', 'wb') as file:
    pickle.dump(val_list, file)    

In [7]:
# 4500 couples
with open('train.pickle', 'rb') as file:
    train_list = pickle.load(file)

with open('val.pickle', 'rb') as file:
    val_list = pickle.load(file)    

In [None]:
# 8000 couples
with open('train2.pickle', 'rb') as file:
    train_list = pickle.load(file)

with open('val2.pickle', 'rb') as file:
    val_list = pickle.load(file)  

In [None]:
# 16000 couples
with open('train3.pickle', 'rb') as file:
    train_list = pickle.load(file)

with open('val3.pickle', 'rb') as file:
    val_list = pickle.load(file)  

In [None]:
# 20000 couples
with open('train4.pickle', 'rb') as file:
    train_list = pickle.load(file)
#2000 couples
with open('val4.pickle', 'rb') as file:
    val_list = pickle.load(file)  

In [9]:
# 2000k couples
with open('train2000k.pickle', 'rb') as file:
    train_list = pickle.load(file)
#2000 couples
with open('val2000k.pickle', 'rb') as file:
    val_list = pickle.load(file)  

In [20]:
len(val_list)

2000

In [27]:
val_1 = [[item[0], item[1],item[2]] for item in val_couples if item[2] == 1]
val_0 = [[item[0], item[1],item[2]] for item in val_couples if item[2] == 0]
val_0_s = val_0[0:3000]
val_1_s = val_1[0:3000]

val_list = val_1_s + val_0_s
random.shuffle(val_list) 
val_loader_basic = DataLoader(val_list, batch_size=16)    

In [14]:
len(val_list)

6000

In [8]:
#train_loader = DataLoader(train_list, batch_size=16)
#val_loader = DataLoader(val_list, batch_size=16)

In [14]:
device=use_GPU()

NVIDIA GeForce RTX 4080 is available and being used


In [15]:
model = PairModel1().to(device)
model.double()

# Sets the path where the model parameters will be stored.
model_path_ = r'C:\\Users\\Alessandro\\Desktop\\Tesi\\modelpair_weights.pth'

In [16]:
W_stored = torch.load(r'C:\\Users\\Alessandro\\Desktop\\Tesi\\PairModel\\Check_points\\1025_175633_13.pt')

model.load_state_dict(W_stored)

<All keys matched successfully>

In [17]:
wandb.init(
      # Set the project where this run will be logged
      project="vecchia_architettura", 
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      notes = "vecchia_architettura shuffle  senza constrastive ",
      # Track hyperparameters and run metadata
      config={
      "learning_rate": 0.00001,
      "architecture": "Model2_mod",
      "epochs": 15,
      "weight_decay": 0.0001,
      "W_crossentropy":1,
      "W_contrastive":0,
      #"type_of_couples": 3,
      #"num_of batch" : len(train_loader),
      "seed": seed,
      "name_saved":'origini_senza_contrastivr'  
      })
      
config = wandb.config

# Calcola i pesi delle classi in base alle percentuali
weight_class_0 = 1.0  # Peso per la classe 0
weight_class_1 = 1.5  # Peso per la classe 1

# Crea un tensore PyTorch con i pesi
weight = torch.tensor([weight_class_0, weight_class_1], dtype=torch.float64).to(device)

criterion = nn.CrossEntropyLoss(weight = weight).to(device)
contrast_criterion = ContrastiveLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
#optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate, weight_decay=config.weight_decay)
num_epochs = config.epochs
best_val_accuracy = 0.0 


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msottile124[0m ([33mpair_fragments[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [34]:
checkpoint_dir = r'C:\\Users\\Alessandro\\Desktop\\Tesi\\PairModel\\Check_points'


checkpoint_interval = 1
epoch_number = 0  


for epoch in range(num_epochs):
    model.train() 

    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    1
    #random.shuffle(train_couples_0)
    #random.shuffle(train_couples_1)

    
    balanced_train_list = []

    
    for _ in range(couples_per_epoch // 2):
        balanced_train_list.append(train_couples_0.pop())
        balanced_train_list.append(train_couples_1.pop())

    
    train_loader = DataLoader(balanced_train_list, batch_size=16,shuffle=True) 
    
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)

    ###########
    ## Train ##
    ###########


    for batch_data in progress_bar:
        optimizer.zero_grad() 
        frags_a, frags_b, labels = batch_data

        frags_a = apply_randomrotations(frags_a)
        frags_b = apply_randomrotations(frags_b)
        
        frags_a = apply_translation(frags_a)
        frags_b = apply_translation(frags_b)

        frags_a = frags_a.double().to(device)
        frags_b = frags_b.double().to(device)
        labels = labels.to(device)
        
        outputs = model(frags_a, frags_b)
        loss_ = criterion(outputs, labels)
        #contrast_loss = contrast_criterion(frags_a, frags_b, labels)
        loss = loss_ 
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

        progress_bar.set_postfix({'Loss': loss.item(), 'Accuracy': correct_predictions / total_samples})
        

    
    accuracy = correct_predictions / total_samples
    train_loss = total_loss/len(train_loader)

    metrics_train = {"train_loss": train_loss, 
                       "accuracy": accuracy}
    wandb.log(metrics_train) 

    ###############
    ## Inference ##
    ###############

    model.eval()  
    
    val_loss_ = 0.0
    val_correct_predictions = 0
    val_total_samples = 0
    val_contrast = 0.0
    with torch.no_grad():
        for val_batch in val_loader_basic:
            val_frags_a, val_frags_b, val_labels = val_batch
            

            val_frags_a = apply_translation(val_frags_a)
            val_frags_b = apply_translation(val_frags_b)

            val_frags_a = val_frags_a.double().to(device)
            val_frags_b = val_frags_b.double().to(device)

            val_labels = val_labels.to(device)
            
            val_outputs = model(val_frags_a, val_frags_b)
            val_loss_ += criterion(val_outputs, val_labels).item()
            #val_contrast += contrast_criterion(val_frags_a, val_frags_b, val_labels).item()
            val_loss = val_loss_ 
            _, val_predicted = torch.max(val_outputs.data, 1)
            val_total_samples += val_labels.size(0)
            val_correct_predictions += (val_predicted == val_labels).sum().item()

    val_accuracy = val_correct_predictions / val_total_samples
    val_loss /= len(val_loader_basic)
    val_metrics = {"val_loss": val_loss, 
                       "val_accuracy": val_accuracy}
    wandb.log(val_metrics)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Training Accuracy: {accuracy:.4f}, '
          f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')


        
    current_time = datetime.datetime.now()
    checkpoint_name = f"{current_time.strftime('%m%d_%H%M%S')}_{epoch + 2}.pt"    
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
    torch.save(model.state_dict(), checkpoint_path)
    

                                                                                         

Epoch [1/15], Training Loss: 0.6638, Training Accuracy: 0.6065, Validation Loss: 0.6365, Validation Accuracy: 0.6375


                                                                                         

Epoch [2/15], Training Loss: 0.6503, Training Accuracy: 0.5983, Validation Loss: 0.6239, Validation Accuracy: 0.6355


                                                                                         

Epoch [3/15], Training Loss: 0.6319, Training Accuracy: 0.6151, Validation Loss: 0.6281, Validation Accuracy: 0.6385


                                                                                         

Epoch [4/15], Training Loss: 0.6292, Training Accuracy: 0.6208, Validation Loss: 0.6164, Validation Accuracy: 0.6508


                                                                                         

Epoch [5/15], Training Loss: 0.6425, Training Accuracy: 0.6155, Validation Loss: 0.6107, Validation Accuracy: 0.6465


                                                                                         

Epoch [6/15], Training Loss: 0.6224, Training Accuracy: 0.6297, Validation Loss: 0.6151, Validation Accuracy: 0.6388


                                                                                         

Epoch [7/15], Training Loss: 0.6201, Training Accuracy: 0.6387, Validation Loss: 0.6102, Validation Accuracy: 0.6467


                                                                                         

Epoch [8/15], Training Loss: 0.6314, Training Accuracy: 0.6206, Validation Loss: 0.6173, Validation Accuracy: 0.6480


                                                                                         

Epoch [9/15], Training Loss: 0.6261, Training Accuracy: 0.6246, Validation Loss: 0.6200, Validation Accuracy: 0.6367


                                                                                          

Epoch [10/15], Training Loss: 0.6376, Training Accuracy: 0.6075, Validation Loss: 0.6139, Validation Accuracy: 0.6332


                                                                                          

Epoch [11/15], Training Loss: 0.6248, Training Accuracy: 0.6265, Validation Loss: 0.6124, Validation Accuracy: 0.6428


                                                                                          

Epoch [12/15], Training Loss: 0.6194, Training Accuracy: 0.6312, Validation Loss: 0.6226, Validation Accuracy: 0.6497


                                                                                          

Epoch [13/15], Training Loss: 0.6374, Training Accuracy: 0.6163, Validation Loss: 0.6107, Validation Accuracy: 0.6433


                                                                                          

Epoch [14/15], Training Loss: 0.6409, Training Accuracy: 0.6110, Validation Loss: 0.6173, Validation Accuracy: 0.6445


                                                                                          

Epoch [15/15], Training Loss: 0.6093, Training Accuracy: 0.6391, Validation Loss: 0.6234, Validation Accuracy: 0.6468


In [19]:
model.eval()  
    
val_loss_ = 0.0
val_correct_predictions = 0
val_total_samples = 0
val_contrast = 0.0
with torch.no_grad():
        for val_batch in val_loader_basic:
            val_frags_a, val_frags_b, val_labels = val_batch
            

            val_frags_a = apply_translation(val_frags_a)
            val_frags_b = apply_translation(val_frags_b)

            val_frags_a = val_frags_a.double().to(device)
            val_frags_b = val_frags_b.double().to(device)

            val_labels = val_labels.to(device)
            
            val_outputs = model(val_frags_a, val_frags_b)
            val_loss_ += criterion(val_outputs, val_labels).item()
            val_contrast += contrast_criterion(val_frags_a, val_frags_b, val_labels).item()
            val_loss = val_loss_ + config.W_contrastive*val_contrast
            _, val_predicted = torch.max(val_outputs.data, 1)
            val_total_samples += val_labels.size(0)
            val_correct_predictions += (val_predicted == val_labels).sum().item()

val_accuracy = val_correct_predictions / val_total_samples
val_loss /= len(val_loader_basic)
val_metrics = {"val_loss": val_loss, 
                       "val_accuracy": val_accuracy}
wandb.log(val_metrics)
print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Training Accuracy: {accuracy:.4f}, '
          f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

Epoch [1/15], Training Loss: 1.3100, Training Accuracy: 0.5447, Validation Loss: 1.2591, Validation Accuracy: 0.6132


In [20]:
current_time = datetime.datetime.now()
checkpoint_name = f"{current_time.strftime('%m%d_%H%M%S')}_{epoch + 1}.pt"    
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
torch.save(model.state_dict(), checkpoint_path)

In [17]:
wandb.finish()

0,1
accuracy,▁█▂
train_loss,█▁▄
val_accuracy,▄█▁
val_loss,█▁▄

0,1
accuracy,0.6008
train_loss,0.64321
val_accuracy,0.62983
val_loss,0.62472


Test


In [None]:
test = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\test_pair_dataset_REG.pt")
test_couples = CreateCouples(test)

test_1 = [[item[0], item[1],item[2]] for item in test_couples if item[2] == 1]
test_0 = [[item[0], item[1],item[2]] for item in test_couples if item[2] == 0]
test_0_s = test_0[0:500]
test_1_s = test_1[0:500]

test_list = test_1_s + test_0_s
random.shuffle(test_list)

100%|██████████| 327/327 [00:00<00:00, 1116.04it/s]


In [None]:
test_loader = DataLoader(test_list, batch_size=16)

In [None]:
device=use_GPU()

NVIDIA GeForce RTX 4080 is available and being used


In [56]:
W_stored = torch.load(r'C:\\Users\\Alessandro\\Desktop\\Tesi\\PairModel\\Check_points\\1025_191445_23.pt')

model.load_state_dict(W_stored)


<All keys matched successfully>

In [61]:
model.eval()
test_correct_predictions = 0
test_total_samples = 0
random.shuffle(val_couples_0)
random.shuffle(val_couples_1)
list_of_results=[]
list_of_true =[]    
balanced_val_list = []

    
#for _ in range(val_couples_per_epoch // 2):
        #balanced_val_list.append(val_couples_0.pop())
        #balanced_val_list.append(val_couples_1.pop())
#random.shuffle(balanced_val_list)        

test_loader = DataLoader(val_couples_c[9], batch_size=16,shuffle = True) 

#progress_bar = tqdm(test_loader,  leave=False)
with torch.no_grad():
        for test_batch in test_loader:
            
            test_frags_a, test_frags_b, test_labels = test_batch
            #print(test_labels)
             
            test_frags_a = apply_translation(test_frags_a)
            test_frags_b = apply_translation(test_frags_b)
            
            test_frags_a = test_frags_a.double().to(device)
            test_frags_b = test_frags_b.double().to(device)

            test_labels = test_labels.to(device)
            list_of_true.append(test_labels.tolist())
            test_outputs = model(test_frags_a, test_frags_b)
                        
            _, test_predicted = torch.max(test_outputs.data, 1)
            test_total_samples += test_labels.size(0)
            test_correct_predictions += (test_predicted == test_labels).sum().item()
            list_of_results.append(test_predicted .tolist())
            #progress_bar.set_postfix({ 'Accuracy': test_correct_predictions / test_total_samples})
test_accuracy = test_correct_predictions / test_total_samples
print(f'Test Accuracy: {test_accuracy:.4f}')

Test Accuracy: 0.6169


In [None]:
#0.6369

In [53]:
len(val_couples_c)

327