In [1]:
# --- 1. Import Libraries ---
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch3d.datasets import ShapeNetCore
from pointRTD import PointRTDModel  # Make sure pointRTD.py is in the same directory
import os

In [2]:
# --- 2. Define Hyperparameters ---
BATCH_SIZE = 64
EPOCHS = 300
CORRUPTION_RATIO = 0.6
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = f"./checkpoints/Pretrain_PointRTD/CR_{CORRUPTION_RATIO}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("Device: ", DEVICE)

Device:  cuda


In [3]:
synset_to_class = {
    "02691156": "airplane",
    "02747177": "ashcan",
    "02773838": "bag",
    "02801938": "bicycle",
    "02808440": "boat",
    "02818832": "bookcase",
    "02828884": "bus",
    "02843684": "cabinet",
    "02871439": "car",
    "02876657": "cellphone",
    "02880940": "chair",
    "02924116": "cone",
    "02933112": "cup",
    "02942699": "bench",
    "02946921": "gun",
    "02954340": "lamp",
    "02958343": "laptop",
    "02992529": "motorcycle",
    "03001627": "piano",
    "03046257": "rifle",
    "03085013": "rocket",
    "03207941": "skateboard",
    "03211117": "sofa",
    "03261776": "table",
    "03325088": "tower",
    "03337140": "train",
    "03467517": "vehicle",
    "03513137": "display",
    "03593526": "washer",
    "03624134": "clock",
    "03636649": "dishwasher",
    "03642806": "earphone",
    "03691459": "firearm",
    "03710193": "furniture",
    "03759954": "fan",
    "03761084": "hat",
    "03790512": "helmet",
    "03797390": "knife",
    "03928116": "lamp",
    "03938244": "loudspeaker",
    "03948459": "mailbox",
    "03991062": "microphone",
    "04004475": "microwave",
    "04074963": "mug",
    "04099429": "pistol",
    "04225987": "pot",
    "04256520": "printer",
    "04330267": "remote",
    "04379243": "bathtub",
    "04401088": "stove",
    "04460130": "dishwasher",
    "04468005": "telephone",
    "04530566": "watercraft",
    "04554684": "guitar"
}


In [4]:
# --- 3. Initialize Dataset and DataLoader ---
import os
import torch
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData
import numpy as np
import random
import os
import torch
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData
import numpy as np
import random

class ShapeNetPointCloudDataset(Dataset):
    def __init__(self, root_dir, split='train', train_ratio=0.7, val_ratio=0.15, seed=42, augment=False, scale_range=(0.8, 1.2), translation_range=(-0.1, 0.1)):
        """
        Initializes the ShapeNetPointCloudDataset with data split options based on class.

        Args:
            root_dir (str): Path to the ShapeNetCore directory.
            split (str): Which data split to use ('train', 'val', or 'test').
            train_ratio (float): Proportion of data for training.
            val_ratio (float): Proportion of data for validation.
            seed (int): Random seed for reproducibility.
        """
        self.root_dir = root_dir
        self.split = split
        self.data = []  # Stores tuples of (file_path, class_label)
        self.augment = augment
        self.scale_range = scale_range
        self.translation_range = translation_range

        # Collect file paths organized by class
        class_files = self.get_class_files()

        # Generate class_name_to_index mapping
        self.class_name_to_index = {class_name: idx for idx, class_name in enumerate(class_files.keys())}
        
        # Split each class's files into train, val, and test sets
        self.create_splits(class_files, train_ratio, val_ratio, seed)

    def get_class_files(self):
        """
        Collects .ply file paths for each class in the ShapeNet directory.
        Returns:
            dict: A dictionary where keys are class labels, and values are lists of file paths.
        """
        class_files = {synset_to_class[synset]: [] for synset in synset_to_class}
        for synset_id, class_name in synset_to_class.items():
            synset_dir = os.path.join(self.root_dir, synset_id)
            for dirpath, _, files in os.walk(synset_dir):
                for file_name in files:
                    if file_name.endswith('.ply'):
                        class_files[class_name].append(os.path.join(dirpath, file_name))
        return class_files

    def create_splits(self, class_files, train_ratio, val_ratio, seed):
        """
        Splits the dataset into train, val, and test sets based on specified ratios.
        
        Args:
            class_files (dict): Dictionary of class labels to lists of file paths.
            train_ratio (float): Proportion of data for training.
            val_ratio (float): Proportion of data for validation.
            seed (int): Random seed for reproducibility.
        """
        random.seed(seed)
        for class_name, files in class_files.items():
            random.shuffle(files)
            total_files = len(files)
            train_len = int(total_files * train_ratio)
            val_len = int(total_files * val_ratio)

            if self.split == 'train':
                split_files = files[:train_len]
            elif self.split == 'val':
                split_files = files[train_len:train_len + val_len]
            elif self.split == 'test':
                split_files = files[train_len + val_len:]
            else:
                raise ValueError("Invalid split; choose from 'train', 'val', or 'test'.")

            self.data.extend([(file_path, class_name) for file_path in split_files])

    def apply_augmentations(self, points):
        """
        Applies random scaling and translation to the point cloud.
        
        Args:
            points (np.ndarray): Point cloud data, shape (num_points, 3).
        
        Returns:
            np.ndarray: Augmented point cloud.
        """
        # Random scaling
        scale_factor = np.random.uniform(*self.scale_range)
        points *= scale_factor

        # Random translation
        translation = np.random.uniform(*self.translation_range, size=(1, 3))
        points += translation

        return points

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Reads a .ply file, subsamples the point cloud, and returns the point cloud and its class label as a tensor.
        
        Args:
            idx (int): Index of the item in the dataset.
        
        Returns:
            tuple: A tuple containing:
                - points (torch.Tensor): Subsampled point cloud data of shape (num_points, 3).
                - class_index (int): Class index for the point cloud.
        """
        num_points = 1024  # Desired number of points after subsampling
        file_path, class_label = self.data[idx]
        ply_data = PlyData.read(file_path)
        points = np.vstack([
            np.array(ply_data['vertex'][axis]) for axis in ['x', 'y', 'z']
        ]).T  # Shape (N, 3)

        # Subsample points if necessary
        if points.shape[0] > num_points:
            indices = np.random.choice(points.shape[0], num_points, replace=False)
            points = points[indices]
        elif points.shape[0] < num_points:
            # Pad with zeros if there are fewer than num_points
            padding = np.zeros((num_points - points.shape[0], 3))
            points = np.vstack([points, padding])

        # Apply augmentations if enabled and split is 'train'
        if self.augment and self.split == 'train':
            points = self.apply_augmentations(points)

        # Convert class label to class index
        class_index = self.class_name_to_index[class_label]
        
        return torch.tensor(points, dtype=torch.float32), class_index



# Usage

# Create datasets for each split
root_dir = './ShapeNetCore.v2/ShapeNetCore.v2'
train_dataset = ShapeNetPointCloudDataset(root_dir, split='train', augment=True)
val_dataset = ShapeNetPointCloudDataset(root_dir, split='val', augment=False)
test_dataset = ShapeNetPointCloudDataset(root_dir, split='test', augment=False)

# Create DataLoaders for each split
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [24]:
# --- 4. Initialize Model, Optimizer, and Criterion ---
token_dim = 256
hidden_dim = 256
num_heads = 8
num_layers = 6
num_patches = 64
num_pts_per_patch = 32
num_channels = 3

model = PointRTDModel(
    input_dim=num_channels,
    token_dim=token_dim,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    corruption_ratio=CORRUPTION_RATIO,
    noise_scale=0.9,
    num_patches=num_patches,
    num_pts_per_patch=num_pts_per_patch,
    finetune=False,
).to(DEVICE)

from torch.optim.lr_scheduler import CosineAnnealingLR

# Initialize the AdamW optimizer with weight decay
LEARNING_RATE = 0.0001
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05) # default lr is 0.001

# Set up a cosine annealing scheduler
# T_max is the number of epochs for the cosine cycle (EPOCHS)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0)

In [25]:
# # Load a model to resume progress
resume_epoch = 37
checkpoint_path = os.path.join(CHECKPOINT_DIR, f"pointrtd_epoch_{resume_epoch}_CR_{CORRUPTION_RATIO}.pth")

def load_pretraining_checkpoint(model, optimizer, scheduler, checkpoint_path):
    if os.path.isfile(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        
        # Load model state
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # Load optimizer state
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Optimizer state loaded successfully.")
        
        # Load scheduler state
        if scheduler and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print("Scheduler state loaded successfully.")
        
        # Resume epoch
        start_epoch = checkpoint.get('epoch', resume_epoch)+1
        print(f"Resuming training from epoch {start_epoch}.")
        
        return start_epoch
    else:
        print(f"Checkpoint not found at {checkpoint_path}")
        return 0  # Start from scratch if checkpoint is missing


        
resume_epoch = load_pretraining_checkpoint(model,optimizer, scheduler, checkpoint_path)

# Manually reset the learning rate
for param_group in optimizer.param_groups:
    param_group['lr'] = LEARNING_RATE  # Set your desired learning rate
    
# Adjust scheduler's learning rate
scheduler.base_lrs = [LEARNING_RATE]  # Match the reduced LR

Loading checkpoint from ./checkpoints/Pretrain_PointRTD/CR_0.6/pointrtd_epoch_37_CR_0.6.pth
Optimizer state loaded successfully.
Scheduler state loaded successfully.
Resuming training from epoch 37.


  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


In [None]:
import time
import os
from torch.utils.tensorboard import SummaryWriter

# Set up the directory for TensorBoard logs
LOG_DIR = "./tensorboard_logs"
os.makedirs(LOG_DIR, exist_ok=True)
writer = SummaryWriter(log_dir=LOG_DIR)



# --- 5. Training Loop ---
for epoch in range(resume_epoch,EPOCHS):
    model.train()
    total_reconstruction_loss = 0.0
    total_discriminator_loss = 0.0
    total_generator_loss = 0.0
    total_full_pcd_loss = 0.0

    for batch_idx, data in enumerate(train_loader):
        pointclouds, _ = data  # Assuming ShapeNetCore provides (pointcloud, label)
        pointclouds = pointclouds.to(DEVICE)

        # Forward pass through PointRTD model
        outputs = model(pointclouds)
        reconstructed_patches, patches, corrupted_mask, discriminator_output, original_tokens, cleaned_tokens, centers = outputs

        # Calculate losses
        reconstruction_loss, discriminator_loss, generator_loss, full_pcd_loss = model.get_loss(
            reconstructed_patches, patches, corrupted_mask, discriminator_output, original_tokens, cleaned_tokens, centers
        )

        # Combine losses
        total_loss = (
            1.0 * reconstruction_loss +
            1.0 * discriminator_loss +
            1.0 * generator_loss +
            1.0 * full_pcd_loss
        )

        # Backpropagation
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Accumulate losses for logging
        total_reconstruction_loss += reconstruction_loss.item()
        total_discriminator_loss += discriminator_loss.item()
        total_generator_loss += generator_loss.item()
        total_full_pcd_loss += full_pcd_loss.item()

        # Log losses to TensorBoard after each batch
        writer.add_scalar("Loss/Total_Loss", total_loss.item(), epoch * len(train_loader) + batch_idx)
        writer.add_scalar("Loss/Reconstruction_Loss", reconstruction_loss.item(), epoch * len(train_loader) + batch_idx)
        writer.add_scalar("Loss/Discriminator_Loss", discriminator_loss.item(), epoch * len(train_loader) + batch_idx)
        writer.add_scalar("Loss/Generator_Loss", generator_loss.item(), epoch * len(train_loader) + batch_idx)
        writer.add_scalar("Loss/Full_Pcd_Loss", full_pcd_loss.item(), epoch * len(train_loader) + batch_idx)

        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx}/{len(train_loader)}], "
                f"Reconstruction Loss: {reconstruction_loss.item():.4f}, "
                f"Discriminator Loss: {discriminator_loss.item():.4f}, "
                f"Generator Loss: {generator_loss.item():.4f}, "
                f"Full PCD CD Loss: {full_pcd_loss.item():.4f}"
            )

    # Step the scheduler after each epoch
    scheduler.step()
    
    # Log learning rate after each epoch
    current_lr = scheduler.get_last_lr()[0]
    writer.add_scalar("Training/Learning_Rate", current_lr, epoch)

    # Log average epoch losses to TensorBoard
    avg_reconstruction_loss = total_reconstruction_loss / len(train_loader)
    avg_discriminator_loss = total_discriminator_loss / len(train_loader)
    avg_generator_loss = total_generator_loss / len(train_loader)
    avg_full_pcd_loss = total_full_pcd_loss / len(train_loader)

    writer.add_scalar("Loss/Epoch_Avg_Reconstruction_Loss", avg_reconstruction_loss, epoch)
    writer.add_scalar("Loss/Epoch_Avg_Discriminator_Loss", avg_discriminator_loss, epoch)
    writer.add_scalar("Loss/Epoch_Avg_Generator_Loss", avg_generator_loss, epoch)
    writer.add_scalar("Loss/Epoch_Avg_Full_Pcd_Loss", avg_full_pcd_loss, epoch)

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] - Avg Reconstruction Loss: {avg_reconstruction_loss:.4f}, "
        f"Avg Discriminator Loss: {avg_discriminator_loss:.4f}, "
        f"Avg Generator Loss: {avg_generator_loss:.4f}, "
        f"Avg Full Pcd Loss: {avg_full_pcd_loss:.4f}"
    )

## Uncomment this block to enable validation monitoring during training
#     # Validation step
#     model.eval()
#     val_reconstruction_loss = 0.0
#     val_discriminator_loss = 0.0
#     val_generator_loss = 0.0
#     with torch.no_grad():
#         for val_batch_idx, val_data in enumerate(val_loader):
#             val_pointclouds, _ = val_data
#             val_pointclouds = val_pointclouds.to(DEVICE)

#             # Forward pass on validation data
#             val_outputs = model(val_pointclouds)
#             val_reconstructed_patches, val_patches, val_corrupted_mask, val_discriminator_output, val_original_tokens, val_cleaned_tokens = val_outputs

#             # Calculate validation losses
#             val_recon_loss, val_disc_loss, val_gen_loss = model.get_loss(
#                 val_reconstructed_patches, val_patches, val_corrupted_mask, val_discriminator_output, val_original_tokens, val_cleaned_tokens
#             )

#             # Accumulate validation losses
#             val_reconstruction_loss += val_recon_loss.item()
#             val_discriminator_loss += val_disc_loss.item()
#             val_generator_loss += val_gen_loss.item()

#     # Calculate average validation losses and log to TensorBoard
#     avg_val_reconstruction_loss = val_reconstruction_loss / len(val_loader)
#     avg_val_discriminator_loss = val_discriminator_loss / len(val_loader)
#     avg_val_generator_loss = val_generator_loss / len(val_loader)

#     writer.add_scalar("Validation_Loss/Reconstruction_Loss", avg_val_reconstruction_loss, epoch)
#     writer.add_scalar("Validation_Loss/Discriminator_Loss", avg_val_discriminator_loss, epoch)
#     writer.add_scalar("Validation_Loss/Generator_Loss", avg_val_generator_loss, epoch)

#     print(
#         f"Validation - Epoch [{epoch+1}/{EPOCHS}] - Avg Reconstruction Loss: {avg_val_reconstruction_loss:.4f}, "
#         f"Avg Discriminator Loss: {avg_val_discriminator_loss:.4f}, "
#         f"Avg Generator Loss: {avg_val_generator_loss:.4f}"
#     )

    # Save checkpoint every 1 epochs
    if (epoch + 1) % 1 == 0:
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"pointrtd_epoch_{epoch+1}_CR_{CORRUPTION_RATIO}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

# Close the TensorBoard writer
writer.close()

print("Training complete.")


Epoch [38/300], Batch [0/548], Reconstruction Loss: 0.0007, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [10/548], Reconstruction Loss: 0.0007, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [20/548], Reconstruction Loss: 0.0006, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [30/548], Reconstruction Loss: 0.0007, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [40/548], Reconstruction Loss: 0.0007, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [50/548], Reconstruction Loss: 0.0006, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [60/548], Reconstruction Loss: 0.0007, Discriminator Loss: 0.0000, Generator Loss: 0.0000, Full PCD CD Loss: 0.0004
Epoch [38/300], Batch [70/548], Rec

In [None]:
# --- 6. Save Reconstructed Point Clouds to H5 Files ---
# This code block saves the complete original and reconstructed point clouds to H5 files for visualization.

import os
import h5py
import torch
import numpy as np


def save_pointcloud_to_h5(pointcloud, file_path):
    """
    Save a point cloud to an H5 file.

    Args:
        pointcloud (np.ndarray or torch.Tensor): The point cloud data, shape (num_points, 3).
        file_path (str): Path to save the H5 file.
    """
    with h5py.File(file_path, 'w') as h5_file:
        h5_file.create_dataset('pointcloud', data=pointcloud.cpu().numpy() if isinstance(pointcloud, torch.Tensor) else pointcloud)
    print(f"Point cloud saved to {file_path}")


def save_complete_pointclouds(batch_idx, original_patches, reconstructed_patches, centers, output_dir):
    """
    Save the complete original and reconstructed point clouds to H5 files.

    Args:
        batch_idx (int): Index of the batch in the data loader.
        original_patches (torch.Tensor): Original patches, shape (batch_size, num_patches, num_points_per_patch, 3).
        reconstructed_patches (torch.Tensor): Reconstructed patches, shape (batch_size, num_patches, num_points_per_patch, 3).
        centers (torch.Tensor): Center coordinates of patches, shape (batch_size, num_patches, 3).
        output_dir (str): Directory to save the H5 files.
    """
    batch_size, num_patches, num_points_per_patch, _ = original_patches.shape

    for sample_idx in range(batch_size):
        # Reconstruct the complete original point cloud
        original_complete = original_patches[sample_idx].clone()
        reconstructed_complete = reconstructed_patches[sample_idx].clone()

        for patch_idx in range(num_patches):
            # Add the center back to unnormalize the patches
            original_complete[patch_idx, :, :3] += centers[sample_idx, patch_idx, :].unsqueeze(0)
            reconstructed_complete[patch_idx, :, :3] += centers[sample_idx, patch_idx, :].unsqueeze(0)

        # Flatten the patches into a single point cloud
        original_complete = original_complete.reshape(-1, 3)  # Shape: (num_patches * num_points_per_patch, 3)
        reconstructed_complete = reconstructed_complete.reshape(-1, 3)

        # Save the complete point clouds
        original_file_path = os.path.join(output_dir, f"original_pcd_batch{batch_idx}_sample{sample_idx}.h5")
        reconstructed_file_path = os.path.join(output_dir, f"reconstructed_pcd_batch{batch_idx}_sample{sample_idx}.h5")

        save_pointcloud_to_h5(original_complete, original_file_path)
        save_pointcloud_to_h5(reconstructed_complete, reconstructed_file_path)

        print(f"Saved complete original and reconstructed PCDs for batch {batch_idx}, sample {sample_idx}")

# --- Load the model checkpoint ---
def load_checkpoint(model, checkpoint_path, device):
    """
    Load a model checkpoint.

    Args:
        model (torch.nn.Module): The model instance.
        checkpoint_path (str): Path to the checkpoint file.
        device (torch.device): The device to load the model on.

    Returns:
        torch.nn.Module: The model with weights loaded from the checkpoint.
    """
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model_state_dict'])
    model.to(device)
    print(f"Checkpoint loaded from {checkpoint_path}")
    return model


# --- Example Usage ---
import torch
import os

# Example checkpoint path and output directory
checkpoint_path = CHECKPOINT_DIR + "/pointrtd_epoch_1_CR_0.6.pth"  # Change to your checkpoint path
output_dir = "./reconstructions"
os.makedirs(output_dir, exist_ok=True)

# Load model and set to evaluation mode
model = load_checkpoint(model, checkpoint_path, DEVICE)
model.train()

# Run the model on test data and save results
with torch.no_grad():
    for batch_idx, (original_pointclouds, _) in enumerate(train_loader):
        if batch_idx == 0:
            original_pointclouds = original_pointclouds.to(DEVICE)

            # Forward pass
            outputs = model(original_pointclouds)
            reconstructed_patches, patches, corrupted_mask, discriminator_output, original_tokens, cleaned_tokens, centers = outputs  # Extract reconstructed patches and original patches
        
            # Calculate losses
            reconstruction_loss, discriminator_loss, generator_loss, full_pcd_loss = model.get_loss(
                reconstructed_patches, patches, corrupted_mask, discriminator_output, original_tokens, cleaned_tokens, centers
            )
            
            print(
                f"Reconstruction Loss: {reconstruction_loss.item():.4f}, "
                f"Discriminator Loss: {discriminator_loss.item():.4f}, "
                f"Generator Loss: {generator_loss.item():.4f}, "
                f"Full PCD CD Loss: {full_pcd_loss.item():.4f}"
            )

            # Save complete original and reconstructed point clouds
            save_complete_pointclouds(batch_idx, patches, reconstructed_patches, centers, output_dir)

            print(f"Processed batch {batch_idx}.")


  model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model_state_dict'])


Checkpoint loaded from ./checkpoints/Pretrain_PointRTD/CR_0.6/pointrtd_epoch_1_CR_0.6.pth
Original tokens:  tensor([[-0.1229, -0.0713, -0.2982,  ..., -0.6140,  0.2125, -0.2669],
        [ 0.0579, -0.0886, -0.2643,  ..., -0.5590,  0.1778, -0.1969],
        [-0.0760, -0.0836, -0.2777,  ..., -0.4937,  0.1785, -0.2067],
        ...,
        [-0.0819, -0.0806, -0.3210,  ..., -0.5912,  0.2249, -0.2458],
        [-0.1197, -0.1165, -0.3180,  ..., -0.5493,  0.1892, -0.2650],
        [-0.0975, -0.0653, -0.2343,  ..., -0.5137,  0.2000, -0.2337]],
       device='cuda:0')
Noisy tokens:  tensor([[-0.1229, -0.0713, -0.2982,  ..., -0.6140,  0.2125, -0.2669],
        [ 0.0579, -0.0886, -0.2643,  ..., -0.5590,  0.1778, -0.1969],
        [-0.0760, -0.0836, -0.2777,  ..., -0.4937,  0.1785, -0.2067],
        ...,
        [-0.0819, -0.0806, -0.3210,  ..., -0.5912,  0.2249, -0.2458],
        [-0.1197, -0.1165, -0.3180,  ..., -0.5493,  0.1892, -0.2650],
        [-0.0975, -0.0653, -0.2343,  ..., -0.5137,  0.20