In [None]:
from pointnet2_utils import PointNetSetAbstraction, PointNetSetAbstractionMsg, PointNetFeaturePropagation, Attention, GLUBlock, LightHead
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import torch
import collections
from scipy.linalg import expm,norm
import glob
import open3d as o3d
import struct
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
import socket
print("Running on:", socket.gethostname())

Running on: PHUC


<h2> Build PointNet++ </h2>

In [2]:
# helper functions:
def farthest_point_sample(xyz, npoint):
    batch_size, n, _ = xyz.shape
    device = xyz.device
    centroids = torch.zeros(batch_size, npoint, dtype=torch.long).to(device)
    distance = torch.ones(batch_size, n).to(device) * 1e10
    farthest = torch.randint(0, n, (batch_size,), dtype=torch.long).to(device)
    batch_indices = torch.arange(batch_size, dtype=torch.long).to(device)

    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].unsqueeze(1)  # [B, 1, 3]
        dist = torch.sum((xyz - centroid) ** 2, -1)  # Squared distance [B, N]
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]  # Index of the farthest point

    return centroids


def gather_points(xyz, idx):
    batch_size, n, _ = xyz.shape

    if idx.dim() == 2:  # Case 1: [B, npoint]
        _, npoint = idx.shape
        idx = idx.view(-1)  # Flatten indices for batch processing
        gathered_xyz = xyz.reshape(batch_size * n, -1)[idx, :]  # Gather points
        gathered_xyz = gathered_xyz.reshape(batch_size, npoint, -1)  # Reshape back
    elif idx.dim() == 3:  # Case 2: [B, npoint, nsample]
        _, npoint, nsample = idx.shape
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * n
        idx = idx + idx_base  # Flatten indices for batch processing
        idx = idx.reshape(-1)  # Flatten completely
        gathered_xyz = xyz.reshape(batch_size * n, -1)[idx, :]  # Gather points
        gathered_xyz = gathered_xyz.reshape(batch_size, npoint, nsample, -1)  # Reshape back

    return gathered_xyz


def query_and_group(xyz, new_xyz, points, radius, nsample):
    B, N, _ = xyz.shape
    _, npoint, _ = new_xyz.shape

    # Compute squared distances between sampled points and all points
    sqrdists = square_distance(new_xyz, xyz)  # [B, npoint, N]

    # Find indices of the nearest neighbors
    group_idx = sqrdists.argsort(dim=-1)[:, :, :nsample]  # [B, npoint, nsample]

    # Gather the grouped xyz coordinates
    grouped_xyz = gather_points(xyz, group_idx)  # [B, npoint, nsample, 3]
    grouped_xyz = grouped_xyz - new_xyz.unsqueeze(2)  # Local coordinates [B, npoint, nsample, 3]

    if points is not None:
        grouped_points = gather_points(points.transpose(1, 2), group_idx).permute(0, 3, 2, 1)  # [B, C, nsample, npoint]
        new_points = torch.cat([grouped_xyz.permute(0, 3, 2, 1), grouped_points], dim=1)  # [B, C+3, nsample, npoint]
    else:
        new_points = grouped_xyz.permute(0, 3, 2, 1)  # [B, 3, nsample, npoint]

    return new_points


def square_distance(src, dst):
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))  # [B, N, M]
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


class SetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        super(SetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample

        # Include x, y, z coordinates in the input channel count
        last_channel = in_channel + 3  # Add (x, y, z)
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))  # Conv2D expects [B, C_in, nsample, npoint]
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        if self.npoint is not None:
            idx = farthest_point_sample(xyz, self.npoint)  # [B, npoint]
            new_xyz = gather_points(xyz, idx)  # [B, npoint, 3]
        else:
            new_xyz = xyz  # Use all points if npoint is None

        grouped_points = query_and_group(xyz, new_xyz, points, self.radius, self.nsample)  # [B, C+3, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            grouped_points = F.relu(self.mlp_bns[i](conv(grouped_points)))  # [B, out_channel, nsample, npoint]

        # Max pooling over nsample dimension
        new_points = torch.max(grouped_points, 2)[0]  # [B, mlp[-1], npoint]
        return new_xyz, new_points
    
    
class Attention(nn.Module):
    """
    Lớp self-attention cho đặc trưng điểm trong point cloud.
    Dùng sau mỗi tầng Set Abstraction để tăng hiệu suất mô hình học hình học.
    Input: [B, C, N] (batch, channel, num_points)
    Output: [B, C, N]
    """
    def __init__(self, in_channels, heads=4):
        super(Attention, self).__init__()
        self.in_channels = in_channels
        self.heads = heads
        self.dk = in_channels // heads
        assert in_channels % heads == 0, "in_channels phải chia hết cho số heads"
        self.query = nn.Conv1d(in_channels, in_channels, 1)
        self.key = nn.Conv1d(in_channels, in_channels, 1)
        self.value = nn.Conv1d(in_channels, in_channels, 1)
        self.proj = nn.Conv1d(in_channels, in_channels, 1)

    def forward(self, x):
        # x: [B, C, N]
        B, C, N = x.shape
        Q = self.query(x).view(B, self.heads, self.dk, N)  # [B, heads, dk, N]
        K = self.key(x).view(B, self.heads, self.dk, N)
        V = self.value(x).view(B, self.heads, self.dk, N)
        attn = torch.einsum('bhdk,bhdk->bhdn', Q, K) / (self.dk ** 0.5)  # [B, heads, N, N]
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('bhdn,bhdn->bhdk', attn, V)  # [B, heads, dk, N]
        out = out.contiguous().view(B, C, N)
        out = self.proj(out)
        return out + x  # residual
    

class GLUBlock(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear_main = nn.Linear(in_dim, out_dim)
        self.linear_gate = nn.Linear(in_dim, out_dim)
    
    def forward(self, x):
        return self.linear_main(x) * torch.sigmoid(self.linear_gate(x))

# Define PointNet++ model
class PointNetPlusPlus(nn.Module):
    def __init__(self, num_classes):
        super(PointNetPlusPlus, self).__init__()

        # Set Abstraction layers
        self.sa1 = SetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=0, mlp=[64,128])
        self.sa1_attention = Attention(in_channels=128, heads=4)  # Attention after first SA layer
        self.sa2 = SetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128, mlp=[128, 256])
        self.sa2_attention = Attention(in_channels=256, heads=4)  # Attention after second SA layer
        self.sa3 = SetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256, mlp=[256,512, 1024])
        self.sa3_attention = Attention(in_channels=1024, heads=4)  # Attention after third SA layer

        # Fully connected layers for classification
        self.light_head = LightHead(in_dim=1024, num_classes=num_classes)

    def forward(self, xyz):
        batch_size, _, _ = xyz.shape

        # Hierarchical feature extraction
        l1_xyz, l1_points = self.sa1(xyz, None)       # Layer 1: [B, 512, 128]
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)  # Layer 2: [B, 128, 256]
        _, l3_points = self.sa3(l2_xyz, l2_points)    # Layer 3: [B, 1024, npoint]
        # Fully connected layers
        x = self.light_head(l3_points)
        return F.log_softmax(x, dim=1)






<h2> Dataset utils </h2>

In [3]:
def extract_unique_labels(label_dir):
    unique_labels = set()
    for label_file in os.listdir(label_dir):
        if label_file.endswith('.txt'):
            with open(os.path.join(label_dir, label_file), 'r') as file:
                for line in file:
                    parts = line.strip().split()
                    unique_labels.add(parts[0])  # Add the label (Class)
    return sorted(unique_labels)

In [4]:
def bin_to_pcd(binFileName):
    size_float = 4
    list_pcd = []
    with open(binFileName, "rb") as f:
        byte = f.read(size_float * 4)
        while byte:
            x, y, z, intensity = struct.unpack("ffff", byte)
            list_pcd.append([x, y, z])
            byte = f.read(size_float * 4)
    np_pcd = np.asarray(list_pcd)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np_pcd)
    return pcd

In [None]:
velodyne_dir = r"D:\Workspace\Python\Point_Cloud_3D_Object_Detection\data\training\velodyne"
label_dir = r"D:\Workspace\Python\Point_Cloud_3D_Object_Detection\data\training\label_2"
calib_dir = r"D:\Workspace\Python\Point_Cloud_3D_Object_Detection\data\training\calib"
unique_labels = extract_unique_labels(label_dir)
label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
print("Extracted label mapping:", label_to_id)

Extracted label mapping: {'Car': 0, 'Cyclist': 1, 'DontCare': 2, 'Misc': 3, 'Pedestrian': 4, 'Person_sitting': 5, 'Tram': 6, 'Truck': 7, 'Van': 8}


In [6]:
def read_velodyne_bin(bin_path):
    """
    Đọc file .bin từ KITTI và trả về mảng (N, 4): x, y, z, reflectance
    """
    return np.fromfile(bin_path, dtype=np.float32).reshape(-1, 4)

In [7]:
def read_kitti_label(label_file):
    """
    Đọc file .txt nhãn từ KITTI object detection.
    Trả về danh sách bounding box + class:
    [class, x, y, z, h, w, l, ry]
    """
    if not label_file.endswith('.txt'):
        return None
    boxes = []
    with open(label_file, 'r') as f:
        for line in f:
            if line.strip() == '':
                continue
            parts = line.strip().split(' ')
            cls = parts[0]
            if cls == 'DontCare':
                continue
            # Extract 3D box info
            h, w, l = map(float, parts[8:11])
            x, y, z = map(float, parts[11:14])
            ry = float(parts[14])
            boxes.append({
                'class': cls,
                'center': [x, y, z],
                'size': [l, w, h],
                'rotation': ry
            })
    return boxes

In [8]:
def read_calib_file(calib_path):
    """
    Đọc file calibration của KITTI và trả về các ma trận chuyển đổi
    """
    data = {}
    with open(calib_path, 'r') as f:
        for line in f.readlines():
            if ':' in line:
                key, value = line.split(':', 1)
                data[key] = np.array([float(x) for x in value.strip().split()])
    
    # Chuyển về ma trận đúng shape
    data['Tr_velo_to_cam'] = data['Tr_velo_to_cam'].reshape(3, 4)
    data['R0_rect'] = data['R0_rect'].reshape(3, 3)
    return data

In [9]:
def cam_to_velo(xyz_cam, calib):
    """
    Chuyển đổi tọa độ từ camera sang lidar (velodyne)
    xyz_cam: (N, 3) - tọa độ trong hệ camera
    calib: dict chứa các ma trận calibration
    """
    # Thêm 1 vào cuối để thành (N, 4) - homogeneous coordinates
    xyz_cam_hom = np.hstack([xyz_cam, np.ones((xyz_cam.shape[0], 1))])
    
    # Lấy ma trận chuyển đổi từ velodyne sang camera
    Tr = calib['Tr_velo_to_cam']  # (3, 4)
    
    # Tính ma trận nghịch đảo để chuyển từ camera sang velodyne
    Tr_inv = np.linalg.pinv(np.vstack([Tr, [0,0,0,1]]))  # (4,4)
    
    # Chuyển đổi tọa độ
    xyz_velo = (Tr_inv @ xyz_cam_hom.T).T[:, :3]
    return xyz_velo

In [10]:
def convert_3d_box_to_velo(box, calib):
    """
    Chuyển đổi 3D bounding box từ hệ camera sang hệ lidar
    box: dict chứa thông tin box {'center': [x,y,z], 'size': [l,w,h], 'rotation': ry}
    calib: dict chứa các ma trận calibration
    """
    center_cam = np.array([box['center']])  # (1, 3)
    center_velo = cam_to_velo(center_cam, calib)[0]  # (3,)
    
    # Kích thước box không đổi khi chuyển hệ tọa độ
    size_velo = box['size']  # [l, w, h]
    
    # Góc quay cần điều chỉnh (thường không đổi nhiều)
    rotation_velo = box['rotation']
    
    return {
        'center': center_velo.tolist(),
        'size': size_velo,
        'rotation': rotation_velo,
        'class': box['class']
    }

In [11]:
def extract_objects_from_pointcloud_with_calib(points, bboxes, class_map, calib):
    """
    Trích xuất object từ point cloud với chuyển đổi hệ tọa độ
    """
    objects = []
    for box in bboxes:
        cls = box['class']
        if cls not in class_map:
            continue
            
        # Chuyển đổi box từ camera sang lidar
        box_velo = convert_3d_box_to_velo(box, calib)
        
        center, size = box_velo['center'], box_velo['size']
        l, w, h = size
        x, y, z = center

        # Hộp trục song song (AABB) trong hệ lidar
        mask = (
            (points[:, 0] > x - l/2) & (points[:, 0] < x + l/2) &
            (points[:, 1] > y - w/2) & (points[:, 1] < y + w/2) &
            (points[:, 2] > z - h/2) & (points[:, 2] < z + h/2)
        )
        pc_object = points[mask][:, :3]
        
        if len(pc_object) >= 30:  # Chỉ lấy object có đủ điểm
            label_id = class_map[str(cls)]
            objects.append((pc_object, label_id))
            
    return objects

In [12]:
def extract_all_objects_with_calib(velodyne_dir, label_dir, calib_dir, class_map):
    """
    Trích xuất tất cả object từ point cloud với chuyển đổi hệ tọa độ
    """
    all_objects = []  # List chứa (pc_object, class_id)
    bin_files = sorted(glob.glob(os.path.join(velodyne_dir, "*.bin")))

    for bin_path in bin_files:
        file_id = os.path.splitext(os.path.basename(bin_path))[0]  # '000012'

        # Đường dẫn đến file label và calib tương ứng
        label_path = os.path.join(label_dir, f"{file_id}.txt")
        calib_path = os.path.join(calib_dir, f"{file_id}.txt")
        
        if not os.path.exists(label_path):
            print(f"[!] Thiếu label cho {file_id}, bỏ qua")
            continue
            
        if not os.path.exists(calib_path):
            print(f"[!] Thiếu calib cho {file_id}, bỏ qua")
            continue

        # Đọc dữ liệu
        points = read_velodyne_bin(bin_path)
        bboxes = read_kitti_label(label_path)
        calib = read_calib_file(calib_path)
        
        # Trích xuất object với chuyển đổi hệ tọa độ
        objects = extract_objects_from_pointcloud_with_calib(points, bboxes, class_map, calib)
        all_objects.extend(objects)
        
    return all_objects

In [13]:
# Hàm test để kiểm tra chuyển đổi
def test_coordinate_conversion():
    """
    Hàm test để kiểm tra việc chuyển đổi hệ tọa độ
    """
    # Đường dẫn test
    velodyne_dir = r"E:\Storange\Python\Point_cloud\data\archive\training\velodyne_subset"
    label_dir = r"E:\Storange\Python\Point_cloud\data\archive\training\label_2_subset"
    calib_dir = r"E:\Storange\Python\Point_cloud\data\archive\training\calib_subset"
    
    # Test với file đầu tiên
    test_file_id = "000000"
    
    # Đọc dữ liệu
    points = read_velodyne_bin(os.path.join(velodyne_dir, f"{test_file_id}.bin"))
    bboxes = read_kitti_label(os.path.join(label_dir, f"{test_file_id}.txt"))
    calib = read_calib_file(os.path.join(calib_dir, f"{test_file_id}.txt"))
    
    print(f"Point cloud shape: {points.shape}")
    print(f"Số bounding boxes: {len(bboxes)}")
    print(f"Calib keys: {list(calib.keys())}")
    
    # Test chuyển đổi một box
    if len(bboxes) > 0:
        box = bboxes[0]
        print(f"\nBox gốc (camera): {box}")
        
        box_velo = convert_3d_box_to_velo(box, calib)
        print(f"Box sau chuyển đổi (lidar): {box_velo}")
        
        # Test trích xuất object
        class_map = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
        objects = extract_objects_from_pointcloud_with_calib(points, bboxes, class_map, calib)
        print(f"Số object trích xuất được: {len(objects)}")
        
        for i, (pc_obj, label_id) in enumerate(objects):
            print(f"Object {i}: class_id={label_id}, num_points={len(pc_obj)}")

if __name__ == "__main__":
    test_coordinate_conversion()

Point cloud shape: (115384, 4)
Số bounding boxes: 1
Calib keys: ['P0', 'P1', 'P2', 'P3', 'R0_rect', 'Tr_velo_to_cam', 'Tr_imu_to_velo']

Box gốc (camera): {'class': 'Pedestrian', 'center': [1.84, 1.47, 8.41], 'size': [1.2, 0.48, 1.89], 'rotation': 0.01}
Box sau chuyển đổi (lidar): {'center': [8.753024291915839, -1.7997219565910108, -1.5464078787587405], 'size': [1.2, 0.48, 1.89], 'rotation': 0.01, 'class': 'Pedestrian'}
Số object trích xuất được: 1
Object 0: class_id=1, num_points=194


<h2> Class Dataset </h2>




In [14]:
all_objs = extract_all_objects_with_calib(velodyne_dir, label_dir,  calib_dir,label_to_id)

print(f"Tổng số object trích ra: {len(all_objs)}")
# all_objs = [(pc1, label1), (pc2, label2), ...]

Tổng số object trích ra: 3263


In [15]:
for pc_obj, label in all_objs:
    print(pc_obj.shape)
    break

(194, 3)


In [16]:
class KittiObjectDataset(Dataset):
    def __init__(self,all_objs,num_points= 1024) -> None:
        self.all_objs = all_objs

        self.num_points = num_points
        
    def __len__(self):
        return len(self.all_objs)
    
    def __getitem__(self, index):
        pc_obj, label = self.all_objs[index]

        if len(pc_obj) > self.num_points:
            idxs = np.random.choice(len(pc_obj), self.num_points, replace=False)
        else:
            idxs = np.random.choice(len(pc_obj), self.num_points, replace=True)
        pc_obj = pc_obj[idxs]

        return torch.tensor(pc_obj, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

In [None]:
# Training Parameters
from torch import optim
NUM_POINTS= 1024
NUM_CLASSES = len(label_to_id)
EPOCHS = 20
BATCH_SIZE = 32

# Initialize model, optimizer, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PointNetPlusPlus(num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

In [None]:

from prettytable import PrettyTable

table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
    if not parameter.requires_grad: continue
    params = parameter.numel()
    table.add_row([name, params])
    total_params+=params
print(table)
print(f"Total Trainable Params: {total_params}")

In [None]:
from torch.utils.data import random_split

dataset = KittiObjectDataset(all_objs, num_points=NUM_POINTS)

len_dataset = len(dataset)
print(f"Tổng số mẫu trong dataset: {len_dataset}")

train_dataset, val_dataset, test_dataset = random_split(dataset,
                                          [round(0.7*len_dataset), round(0.1*len_dataset), round(0.2*len_dataset)],
                                          generator=torch.Generator().manual_seed(42))

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
for batch in train_dataloader:
    points, labels = batch
    print(points.shape)
    print(labels.shape)
    break

In [None]:
train_loss = []
test_loss = []
train_acc = []
test_acc = []
best_loss= np.inf

def training_loop(epochs, model, train_loader,val_dataloader, optimizer, criterion,num_points):
    """
    Hàm huấn luyện mô hình PointNet++ với dữ liệu point cloud từ KITTI.
    """
    model.train()
    correct = 0
    total = 0
    for epoch in tqdm(range(epochs)):
        epoch_train_loss = []
        epoch_train_acc = []

        for points, labels in train_loader:
            points, labels = points.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(points)  # No need for reshaping

            loss = criterion(outputs, labels)

            epoch_train_loss.append(loss.cpu().item())
            loss.backward()

            optimizer.step()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            
            total += labels.size(0)

            accuracy = float(total.item() / correct)
            epoch_train_acc.append(accuracy)

        epoch_test_loss = []
        epoch_test_acc = []

        for points, labels in val_dataloader:
            points, labels = points.to(device), labels.to(device)

            model = model.eval()
            outputs = model(points)
            loss = criterion(outputs, labels)
            epoch_test_loss.append(loss.cpu().item())
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).sum().item()
            accuracy = float(correct / labels.size(0))
            epoch_test_acc.append(accuracy)
    print('Epoch %s: train loss: %s, val loss: %f, train accuracy: %s,  val accuracy: %f'
              % (epoch,
                round(np.mean(epoch_train_loss), 4),
                round(np.mean(epoch_test_loss), 4),
                round(np.mean(epoch_train_acc), 4),
                round(np.mean(epoch_test_acc), 4)))
    if np.mean(test_loss) < best_loss:
        state = {
            'model':model.state_dict(),
            'optimizer':optimizer.state_dict()
        }
        torch.save(state, os.path.join('checkpoints', '3DKitti_checkpoint_%s.pth' % (num_points)))
        best_loss=np.mean(test_loss)
    train_loss.append(np.mean(epoch_train_loss))
    test_loss.append(np.mean(epoch_test_loss))
    train_acc.append(np.mean(epoch_train_acc))
    test_acc.append(np.mean(epoch_test_acc))
    return train_loss, train_acc,test_loss,test_acc


In [None]:
train_loss, train_acc,test_loss,test_acc= training_loop(EPOCHS, model, train_dataloader,val_dataloader, optimizer, criterion,NUM_POINTS)

In [None]:
def plot_losses(train_loss, test_loss, save_to_file=None):
    fig = plt.figure()
    epochs = len(train_loss)
    plt.plot(range(epochs), train_loss, 'b', label='Training loss')
    plt.plot(range(epochs), test_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    if save_to_file:
        fig.savefig(save_to_file,dpi=200)

def plot_accuracy(train_acc, test_acc, save_to_file=None):
    fig = plt.figure()
    epochs = len(train_acc)
    plt.plot(range(epochs), train_acc, 'b', label='Training accuracy')
    plt.plot(range(epochs), test_acc, 'r', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.legend()
    if save_to_file:
        fig.savefig(save_to_file,dpi=200)


In [None]:
plot_losses(train_loss, test_loss, save_to_file=False)
plot_accuracy(train_acc, test_acc, save_to_file=False)

<h2> Testing </h2>

In [None]:
# # Ensure model is in evaluation mode
# model.eval()

# # Initialize a list to store predictions
# predictions = []

# with torch.no_grad():
#     for points in test_loader:
#         points = points.to(device)  # Send points to GPU if available
#         outputs = model(points)  # Get predictions
#         _, predicted_classes = torch.max(outputs, 1)  # Predicted class indices
#         predictions.append(predicted_classes.cpu().numpy())  # Store predictions

# # Flatten predictions into a single array
# predictions = np.concatenate(predictions, axis=0)