In [11]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import random
import pickle
import datetime
import os

In [2]:
def use_GPU():
    """Pick the torch device to run on and report the choice.

    Returns:
        torch.device: ``cuda`` when CUDA is available, ``cpu`` otherwise.
        A one-line message naming the selected device is printed either way.
    """
    if not torch.cuda.is_available():
        print("GPU is not available, using CPU instead")
        return torch.device("cpu")

    print(torch.cuda.get_device_name(0), "is available and being used")
    return torch.device("cuda")



def CreateCouples(pt):
    """Flatten a collection of clusters into labelled fragment pairs.

    For every cluster ``pt[i]``, ``pt[i][0]`` is read as a square matrix
    (the adjacency matrix) and ``pt[i][1]`` as the fragments of that cluster.
    Each unordered pair ``(j, k)`` with ``j < k`` yields one entry
    ``[fragment_j, fragment_k, matrix[j][k]]``.

    Args:
        pt: indexable container exposing ``shape[0]`` (the number of clusters).

    Returns:
        list: all pairs of all clusters, in cluster order and then in
        row-major ``(j, k)`` order.
    """
    pairs = []
    n_clusters = pt.shape[0]

    for cluster_idx in tqdm(range(n_clusters)):
        adjacency = pt[cluster_idx][0]
        fragments = pt[cluster_idx][1]
        # The matrix side length gives the number of fragments in the cluster.
        size = adjacency.shape[0]

        pairs.extend(
            [fragments[first], fragments[second], adjacency[first][second]]
            for first in range(size - 1)
            for second in range(first + 1, size)
        )

    return pairs

In [9]:
# https://github.com/qq456cvb/Point-Transformers

def square_distance(src, dst):
    """
    Squared Euclidean distance between every source point and every target point.

    The difference of each (source, target) pair is formed by broadcasting,
    then squared and summed over the coordinate axis:
        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    pairwise_diff = src.unsqueeze(2) - dst.unsqueeze(1)  # [B, N, M, C]
    return pairwise_diff.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Gather rows of ``points`` along dimension 1 using per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    target_shape = tuple(idx.shape)
    # Collapse every index dimension after the batch so one gather handles all cases.
    flat_idx = idx.reshape(target_shape[0], -1)
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, points.shape[-1])
    picked = points.gather(1, gather_idx)
    # Restore the original index layout and append the feature axis.
    return picked.reshape(*target_shape, -1)


def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling.

    Starting from one randomly chosen point per batch element, the point with
    the largest squared distance to the already selected set is picked at
    every step.

    Input:
        xyz: pointcloud data, [B, N, C] (C is read from the input shape)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running squared distance from each point to its closest selected point.
    distance = torch.full((B, N), 1e10, device=device)
    # Drawn from the default generator and then moved, so the random start is
    # produced the same way regardless of the target device.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual coordinate dimension instead of assuming 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids

def sample_and_group(npoint, nsample, xyz, points):
    """
    Sample ``npoint`` centres and group the features of their nearest neighbours.

    Input:
        npoint: number of centres chosen by farthest point sampling
        nsample: number of nearest neighbours kept per centre
        xyz: point coordinates, [B, N, C]
        points: per-point features, [B, N, D]
    Return:
        new_xyz: coordinates of the sampled centres, [B, npoint, C]
        new_points: for every neighbour, its features relative to the centre
            concatenated with the centre features, [B, npoint, nsample, 2*D]
    """
    B, _, _ = xyz.shape

    # Centres: indices, coordinates and features.
    centre_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, centre_idx)
    centre_feats = index_points(points, centre_idx)

    # Neighbourhoods: the nsample closest points to each centre.
    neighbour_idx = square_distance(new_xyz, xyz).argsort()[:, :, :nsample]  # B x npoint x K
    neighbour_feats = index_points(points, neighbour_idx)

    centre_feats = centre_feats.view(B, npoint, 1, -1)
    relative_feats = neighbour_feats - centre_feats
    tiled_centres = centre_feats.repeat(1, 1, nsample, 1)
    return new_xyz, torch.cat([relative_feats, tiled_centres], dim=-1)


class Local_op(nn.Module):
    """Point-wise two-layer MLP applied to each neighbourhood, then max-pooled.

    Both layers are 1x1 convolutions followed by batch norm and ReLU; the
    maximum over the neighbours is kept for every output channel.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pointwise = dict(kernel_size=1, bias=False)
        self.conv1 = nn.Conv1d(in_channels, out_channels, **pointwise)
        self.conv2 = nn.Conv1d(out_channels, out_channels, **pointwise)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map grouped features [b, n, s, d] to pooled features [b, out_channels, n]."""
        n_batch, n_groups, n_neighbours, n_feats = x.size()  # e.g. torch.Size([32, 512, 32, 6])

        # Treat every group as its own sample: (b * n, d, s)
        grouped = x.permute(0, 1, 3, 2).reshape(-1, n_feats, n_neighbours)

        hidden = self.relu(self.bn1(self.conv1(grouped)))  # (b * n, out, s)
        hidden = self.relu(self.bn2(self.conv2(hidden)))   # (b * n, out, s)

        # Strongest response over the neighbours of each group.
        pooled = hidden.max(dim=2)[0]                       # (b * n, out)
        return pooled.reshape(n_batch, n_groups, -1).permute(0, 2, 1)


class SA_Layer(nn.Module):
    """Self-attention layer over the points of a feature map [b, c, n].

    Query and key projections share a single weight tensor. The attended
    features are subtracted from the input, transformed, normalised and
    added back to the input as a residual.
    """

    def __init__(self, channels):
        super().__init__()
        reduced = channels // 4
        self.q_conv = nn.Conv1d(channels, reduced, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, reduced, 1, bias=False)
        # Tie the query projection to the key projection.
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return a tensor of the same shape as ``x`` ([b, c, n])."""
        queries = self.q_conv(x).transpose(1, 2)   # b, n, c/4
        keys = self.k_conv(x)                      # b, c/4, n
        values = self.v_conv(x)                    # b, c, n

        weights = self.softmax(torch.matmul(queries, keys))  # b, n, n
        # Second normalisation, along dim 1, after the softmax over the last dim.
        column_sums = weights.sum(dim=1, keepdim=True)
        weights = weights / (1e-9 + column_sums)

        attended = torch.matmul(values, weights)   # b, c, n
        residual = self.act(self.after_norm(self.trans_conv(x - attended)))
        return x + residual
    

class StackedAttention(nn.Module):
    """Two point-wise conv layers followed by four chained self-attention layers.

    The outputs of the four attention layers are concatenated along the
    channel axis, so a [B, channels, N] input yields [B, 4 * channels, N].
    """

    def __init__(self, channels=256):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)

        self.sa1 = SA_Layer(channels)
        self.sa2 = SA_Layer(channels)
        self.sa3 = SA_Layer(channels)
        self.sa4 = SA_Layer(channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        # Unpacking doubles as a check that the input is a 3-D [B, D, N] tensor.
        batch_size, _, num_points = x.size()

        features = self.relu(self.bn1(self.conv1(x)))         # B, D, N
        features = self.relu(self.bn2(self.conv2(features)))  # B, D, N

        # Each attention layer consumes the output of the previous one;
        # every intermediate result is kept for the final concatenation.
        stage_outputs = []
        for attention in (self.sa1, self.sa2, self.sa3, self.sa4):
            features = attention(features)
            stage_outputs.append(features)

        return torch.cat(stage_outputs, dim=1)

In [8]:

class Branch(nn.Module):
    """Encoder that turns one fragment's point cloud into a 1024-d descriptor.

    Input to ``forward``: ``x`` of shape [B, N, 7]; the first three features of
    every point are used as coordinates for sampling and grouping.
    Output: [B, 1024], the per-channel maximum over the remaining points.
    """

    def __init__(self):
        super().__init__()

        d_points = 7 # we have 7 features for each point
        self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        # Grouping concatenates relative and centre features, doubling the channels.
        self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
        self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
        self.pt_last = StackedAttention()

        self.relu = nn.ReLU()
        # 1280 = 4 * 256 (stacked attention outputs) + 256 (its input features)
        self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x):
        xyz = x[..., :3]
        x = x.permute(0, 2, 1)
        batch_size, _, _ = x.size()
        # Cast to the dtype of this module's own weights rather than forcing
        # float64: identical when the model has been converted with .double(),
        # and it lets a float32 model run without a dtype mismatch.
        x = x.to(self.conv1.weight.dtype)
        x = self.relu(self.bn1(self.conv1(x))) # B, D, N
        x = self.relu(self.bn2(self.conv2(x))) # B, D, N
        x = x.permute(0, 2, 1)
        new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x)
        feature_0 = self.gather_local_0(new_feature)
        feature = feature_0.permute(0, 2, 1)
        new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature)
        feature_1 = self.gather_local_1(new_feature)

        x = self.pt_last(feature_1)
        x = torch.cat([x, feature_1], dim=1)
        x = self.conv_fuse(x)
        x = torch.max(x, 2)[0] # Max over the point dimension (dim 2): one value per channel.
        x = x.view(batch_size, -1) # Flatten to [B, 1024].

        return x
    
    
class PairModel1(nn.Module):
    """Two-branch classifier deciding whether a pair of fragments belongs together.

    Each fragment is encoded by its own ``Branch``; the two descriptors are
    summed and passed through a three-layer classification head that returns
    two logits per pair.
    """

    def __init__(self):
        super().__init__()

        n_classes = 2  # it's a binary classification

        # One encoder per fragment of the pair (weights are not shared).
        self.branch1 = Branch()
        self.branch2 = Branch()

        self.relu = nn.ReLU()

        # classification head
        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn1 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, n_classes)

    def forward(self, batch_1, batch_2):
        """Return the class logits, shape [B, 2], for a batch of fragment pairs."""
        # Fuse the two branch descriptors by element-wise sum.
        fused = self.branch1(batch_1) + self.branch2(batch_2)

        # classification head
        hidden = self.dp1(self.relu(self.bn1(self.linear1(fused))))
        hidden = self.dp2(self.relu(self.bn2(self.linear2(hidden))))
        return self.linear3(hidden)


In [5]:
# Data loading: read the serialized train and validation pair datasets from disk.
# NOTE(review): the paths are absolute and specific to one Windows machine; they
# must be adapted before the notebook can run anywhere else.
train = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\train_pair_dataset_REG.pt")
val = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\val_pair_dataset_REG.pt")
# The test split is loaded later, in the Test section.
#test = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\test_pair_dataset_REG.pt")

In [6]:
# Change the shape of the data: expand every cluster into a flat list of
# [fragment_a, fragment_b, label] entries, as produced by CreateCouples.
train_couples = CreateCouples(train)
val_couples = CreateCouples(val)
# The test couples are built later, in the Test section.
#test_couples = CreateCouples(test)

100%|██████████| 1526/1526 [00:01<00:00, 1237.63it/s]
100%|██████████| 327/327 [00:00<00:00, 1162.12it/s]


In [7]:
# Total number of fragment pairs generated from the training clusters.
len(train_couples)

1713243

In [8]:
# Keep only a small, class-balanced subset of couples so that the training
# cycle runs faster: the first 4000 couples of each label for training and
# the first 500 of each label for validation, shuffled once.
el_1 = [[frag_a, frag_b, label] for frag_a, frag_b, label in train_couples if label == 1]
el_0 = [[frag_a, frag_b, label] for frag_a, frag_b, label in train_couples if label == 0]
el_0_s = el_0[:4000]
el_1_s = el_1[:4000]

train_list = [*el_1_s, *el_0_s]
random.shuffle(train_list)


# Same selection for the validation couples.
val_1 = [[frag_a, frag_b, label] for frag_a, frag_b, label in val_couples if label == 1]
val_0 = [[frag_a, frag_b, label] for frag_a, frag_b, label in val_couples if label == 0]
val_0_s = val_0[:500]
val_1_s = val_1[:500]

val_list = [*val_1_s, *val_0_s]
random.shuffle(val_list)

In [9]:
# Cache the selected subsets on disk so that later sessions can skip the
# pair generation step.
for _pickle_path, _couples in (('train2.pickle', train_list), ('val2.pickle', val_list)):
    with open(_pickle_path, 'wb') as file:
        pickle.dump(_couples, file)

In [4]:
# 4500 couples
# Reload a previously cached subset from disk.
# NOTE(review): the following cell reloads train_list / val_list from the
# 'train2' / 'val2' pickles, so when both cells run these values are replaced.
# pickle.load can execute arbitrary code: only load files created locally.
with open('train.pickle', 'rb') as file:
    train_list = pickle.load(file)

with open('val.pickle', 'rb') as file:
    val_list = pickle.load(file)    

In [10]:
# 8000 couples
# Reload the subsets written to 'train2.pickle' / 'val2.pickle' by the dump cell.
# pickle.load can execute arbitrary code: only load files created locally.
with open('train2.pickle', 'rb') as file:
    train_list = pickle.load(file)

with open('val2.pickle', 'rb') as file:
    val_list = pickle.load(file)  

In [6]:
# Reshuffle the training couples at every epoch so that the batches (and the
# batch-norm statistics computed on them) are not identical from epoch to epoch.
train_loader = DataLoader(train_list, batch_size=32, shuffle=True)
# Evaluation does not depend on the order, so the validation loader keeps it fixed.
val_loader = DataLoader(val_list, batch_size=32)

In [13]:
# Select the device (GPU when available, otherwise CPU) for the model and the batches.
device=use_GPU()

NVIDIA GeForce RTX 4080 is available and being used


In [12]:
# Build the pair classifier on the selected device and convert its parameters
# to float64, matching the .double() batches fed to it during training.
model = PairModel1().to(device).double()

# Sets the path where the model parameters will be stored.
# NOTE(review): this is a raw string, so the doubled backslashes are kept literally.
model_path_ = r'C:\\Users\\Alessandro\\Desktop\\Tesi\\modelpair_weights.pth'

In [14]:
# Training configuration.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, weight_decay=0.001)
num_epochs = 10
best_val_accuracy = 0.0

checkpoint_dir = r'C:\\Users\\Alessandro\\Desktop\\Tesi\\PairModel\\Check_points'
# Create the checkpoint folder up front: otherwise the first periodic save
# fails only after several epochs of training have already been spent.
os.makedirs(checkpoint_dir, exist_ok=True)

# A checkpoint is written every `checkpoint_interval` epochs.
checkpoint_interval = 3
epoch_number = 0


for epoch in range(num_epochs):
    model.train()

    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)

    ###########
    ## Train ##
    ###########

    for batch_data in progress_bar:
        optimizer.zero_grad()
        frags_a, frags_b, labels = batch_data

        frags_a = frags_a.double().to(device)
        frags_b = frags_b.double().to(device)
        labels = labels.to(device)

        outputs = model(frags_a, frags_b)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Predicted class = index of the largest logit (no gradient needed).
        predicted = outputs.argmax(dim=1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

        # Running accuracy over the batches seen so far in this epoch.
        progress_bar.set_postfix({'Loss': loss.item(), 'Accuracy': correct_predictions / total_samples})

    accuracy = correct_predictions / total_samples
    # Mean of the per-batch losses.
    train_loss = total_loss/len(train_loader)

    ###############
    ## Inference ##
    ###############

    model.eval()

    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0

    with torch.no_grad():
        for val_batch in val_loader:
            val_frags_a, val_frags_b, val_labels = val_batch

            val_frags_a = val_frags_a.double().to(device)
            val_frags_b = val_frags_b.double().to(device)

            val_labels = val_labels.to(device)

            val_outputs = model(val_frags_a, val_frags_b)
            val_loss += criterion(val_outputs, val_labels).item()

            val_predicted = val_outputs.argmax(dim=1)
            val_total_samples += val_labels.size(0)
            val_correct_predictions += (val_predicted == val_labels).sum().item()

    val_accuracy = val_correct_predictions / val_total_samples
    val_loss /= len(val_loader)

    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Training Accuracy: {accuracy:.4f}, '
          f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    # Keep the weights of the epoch with the best validation accuracy so far.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), model_path_)

    # Periodic, timestamped checkpoint, independent of the validation score.
    if (epoch + 1) % checkpoint_interval == 0:

        current_time = datetime.datetime.now()
        checkpoint_name = f"{current_time.strftime('%Y%m%d_%H%M%S')}_{epoch + 1}.pt"

        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        torch.save(model.state_dict(), checkpoint_path)
    epoch_number += 1

                                                                                         

Epoch [1/10], Training Loss: 0.6735, Training Accuracy: 0.6122, Validation Loss: 0.6497, Validation Accuracy: 0.6240


                                                                                         

Epoch [2/10], Training Loss: 0.5603, Training Accuracy: 0.7129, Validation Loss: 0.6544, Validation Accuracy: 0.6170


                                                                                         

Epoch [3/10], Training Loss: 0.4979, Training Accuracy: 0.7727, Validation Loss: 0.6639, Validation Accuracy: 0.6030


                                                                                         

Epoch [4/10], Training Loss: 0.4318, Training Accuracy: 0.8109, Validation Loss: 0.6850, Validation Accuracy: 0.5990


                                                                                         

Epoch [5/10], Training Loss: 0.3726, Training Accuracy: 0.8547, Validation Loss: 0.7091, Validation Accuracy: 0.5930


                                                                                         

Epoch [6/10], Training Loss: 0.3192, Training Accuracy: 0.8858, Validation Loss: 0.7424, Validation Accuracy: 0.5810


                                                                                         

Epoch [7/10], Training Loss: 0.2837, Training Accuracy: 0.8987, Validation Loss: 0.7706, Validation Accuracy: 0.5640


                                                                                         

Epoch [8/10], Training Loss: 0.2488, Training Accuracy: 0.9131, Validation Loss: 0.8258, Validation Accuracy: 0.5640


                                                                                         

Epoch [9/10], Training Loss: 0.2192, Training Accuracy: 0.9309, Validation Loss: 0.8999, Validation Accuracy: 0.5430


                                                                                          

Epoch [10/10], Training Loss: 0.1974, Training Accuracy: 0.9391, Validation Loss: 0.9660, Validation Accuracy: 0.5320


Test


In [15]:
# Load the test split and expand its clusters into labelled couples.
test = torch.load("C:\\Users\\Alessandro\\Desktop\\Tesi\\pair_dataset\\dataset_1024_AB\\test_pair_dataset_REG.pt")
test_couples = CreateCouples(test)

# Class-balanced subset: the first 500 couples of each label, shuffled once.
test_1 = [[frag_a, frag_b, label] for frag_a, frag_b, label in test_couples if label == 1]
test_0 = [[frag_a, frag_b, label] for frag_a, frag_b, label in test_couples if label == 0]
test_0_s = test_0[:500]
test_1_s = test_1[:500]

test_list = [*test_1_s, *test_0_s]
random.shuffle(test_list)

100%|██████████| 327/327 [00:00<00:00, 962.06it/s]


In [16]:
# Batches of 32 test couples, served in the order of test_list.
test_loader = DataLoader(test_list, batch_size=32)

In [17]:
# Select the device again (GPU when available, otherwise CPU) for the test run.
device=use_GPU()

NVIDIA GeForce RTX 4080 is available and being used


In [18]:
# Restore the weights stored at model_path_ (written by the training loop
# whenever the validation accuracy improved) into the existing model.
W_stored = torch.load(model_path_)
model.load_state_dict(W_stored)

<All keys matched successfully>

In [19]:
# Evaluate the restored model on the test subset: accuracy only, no gradients.
model.eval()
test_correct_predictions = 0
test_total_samples = 0

with torch.no_grad():
    for test_frags_a, test_frags_b, test_labels in test_loader:
        test_frags_a = test_frags_a.double().to(device)
        test_frags_b = test_frags_b.double().to(device)
        test_labels = test_labels.to(device)

        test_outputs = model(test_frags_a, test_frags_b)

        # Index of the largest logit is the predicted class.
        test_predicted = torch.max(test_outputs.data, 1)[1]
        test_total_samples += test_labels.size(0)
        test_correct_predictions += (test_predicted == test_labels).sum().item()

test_accuracy = test_correct_predictions / test_total_samples
print(f'Test Accuracy: {test_accuracy:.4f}')

Test Accuracy: 0.5710
