In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics as tm
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

from datasets import ModelNet40, ScanObjectNN
from pooling import MA_Pooling, KNN_Pooling
from embedding_modules import Naive_Embedding,  Neighboorhood_Embedding
from Attention_modules.attention_blocks import Self_Attention_Block, MultiHead_Attention_Block, MultiHead_Attention_Block3

In [None]:
def create_datasets(dataset, batch_size, num_points=1024, num_workers=8, val_fraction=0.5, seed=42):
    """Build train / validation / test DataLoaders for a point-cloud dataset class.

    The dataset's own ``test`` partition is split into a validation subset and a
    held-out test subset. The split is deterministic for a given ``seed``. The
    ``train`` partition is loaded with ``data_augmentation=True``.

    Args:
        dataset: dataset class, called as ``dataset(num_points, partition=...)``
            and ``dataset(num_points, partition='train', data_augmentation=True)``.
        batch_size: batch size used by all three loaders.
        num_points: passed as the first positional argument of ``dataset``.
        num_workers: worker processes per DataLoader.
        val_fraction: fraction of the test partition assigned to validation
            (rounded down).
        seed: seed of the generator used for the validation/test split.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    # Split the test partition into validation and held-out test subsets.
    full_test_set = dataset(num_points, partition='test')
    validation_length = int(len(full_test_set) * val_fraction)
    test_length = len(full_test_set) - validation_length
    validation_set, test_set = random_split(
        full_test_set, [validation_length, test_length],
        generator=torch.Generator().manual_seed(seed))

    train_set = dataset(num_points, partition='train', data_augmentation=True)
    # drop_last=True presumably avoids a trailing batch of size 1. The models'
    # BatchNorm1d layers reject such a batch in training mode.
    train_dataloader = DataLoader(train_set, num_workers=num_workers,
                                  batch_size=batch_size, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(validation_set, num_workers=num_workers,
                                batch_size=batch_size, shuffle=False, drop_last=False)
    test_dataloader = DataLoader(test_set, num_workers=num_workers,
                                 batch_size=batch_size, shuffle=False, drop_last=False)

    return train_dataloader, val_dataloader, test_dataloader

In [None]:
# Based on the original PCT classifier, referred to as "PCT" in the blog post.
class PCT_Classifier(pl.LightningModule):
    """Original-style Point Cloud Transformer (PCT) classifier.

    ``forward`` runs these steps in order:
    1. embedding;
    2. ``Self_Attention_Block``;
    3. concatenation of the embedding with the attention output along dim 1;
    4. 1x1 ``Conv1d`` fusion;
    5. ``MA_Pooling``;
    6. MLP classification head.

    Args:
        num_classes: number of output logits.
        learning_rate: AdamW learning rate.
        input_features: per-point input features given to the embedding.
        attention_layers: number of layers passed to ``Self_Attention_Block``.
        encoder_channels: channel width produced by the embedding.
        key_size, value_size: forwarded to ``Self_Attention_Block``.
        pooling: "max", "avg" or "both"; forwarded to ``MA_Pooling``.
            "both" doubles the input width of the classification head.
        linear_encoder_layer: output channels of the fusion convolution.
        classification_layer_size: width of the last hidden layer of the head.
            The first hidden layer is twice as wide.
        dropout: dropout probability in the classification head.
        naive_embedding: if truthy, use ``Naive_Embedding``; otherwise use
            ``Neighboorhood_Embedding``.
        k, sampling, positional_embedding: forwarded to ``Neighboorhood_Embedding``.
        scheduler_t_max: ``T_max`` of the cosine-annealing LR schedule
            (previously hard-coded to 75).

    Raises:
        ValueError: if ``pooling`` is not "both", "max" or "avg".
    """

    def __init__(self, num_classes, learning_rate=0.01, input_features=3, attention_layers=4, encoder_channels=256,
                 key_size=0.25, value_size=1, pooling="both", linear_encoder_layer=1024,
                 classification_layer_size=256, dropout=0.5,
                 naive_embedding=False, k=32, sampling=0.25, positional_embedding=True,
                 scheduler_t_max=75):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if pooling not in {"both", "max", "avg"}:
            raise ValueError("Pooling must be either max, avg or both")

        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.input_features = input_features
        self.attention_layers = attention_layers
        self.encoder_channels = encoder_channels
        self.key_size = key_size
        self.value_size = value_size
        self.linear_encoder_layer = linear_encoder_layer
        self.pooling = pooling
        self.classification_layer_size = classification_layer_size
        self.dropout = dropout
        self.k = k
        self.sampling = sampling
        self.naive_embedding = naive_embedding
        self.positional_embedding = positional_embedding
        self.scheduler_t_max = scheduler_t_max

        # Any non-truthy value selects the neighbourhood embedding. The previous
        # `== True` / `== False` chain left `self.embedding` unset otherwise.
        if self.naive_embedding:
            self.embedding = Naive_Embedding(self.input_features, self.encoder_channels)
        else:
            self.embedding = Neighboorhood_Embedding(self.input_features, self.encoder_channels,
                                                     self.k, self.sampling, self.positional_embedding)

        self.attention_block = Self_Attention_Block(self.attention_layers, self.encoder_channels,
                                                    self.key_size, self.value_size)

        # The fusion input is the embedding (encoder_channels) concatenated with
        # the attention output. The in_channels below assume the attention block
        # returns attention_layers * encoder_channels channels.
        # TODO: confirm in Self_Attention_Block.
        # LeakyReLU(0.02) follows the authors' reference implementation.
        self.conv_fuse = nn.Sequential(
            nn.Conv1d(self.encoder_channels * (self.attention_layers + 1), self.linear_encoder_layer,
                      kernel_size=1, bias=False),
            nn.BatchNorm1d(self.linear_encoder_layer),
            nn.LeakyReLU(0.02))

        self.ma_pooling = MA_Pooling(self.pooling)
        # "both" presumably concatenates max- and avg-pooled features, hence twice the width.
        if self.pooling == "both":
            self.classification_input = self.linear_encoder_layer * 2
        else:
            self.classification_input = self.linear_encoder_layer

        self.classification_layers = nn.Sequential(
            nn.Linear(self.classification_input, self.classification_layer_size * 2, bias=False),
            nn.BatchNorm1d(self.classification_layer_size * 2),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.classification_layer_size * 2, self.classification_layer_size, bias=False),
            nn.BatchNorm1d(self.classification_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.classification_layer_size, self.num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of point clouds."""
        x, coordinates = self.embedding(x)
        attention = self.attention_block(x, coordinates)
        # Channel-wise concatenation. Conv1d below implies (batch, channels, points) layout.
        x = torch.cat((x, attention), dim=1)
        x = self.conv_fuse(x)
        x = self.ma_pooling(x)
        return self.classification_layers(x)

    def _shared_step(self, batch):
        """Run the model on one ``(x, y)`` batch and return ``(loss, accuracy)``.

        NOTE(review): torchmetrics >= 0.11 requires ``task=`` / ``num_classes=``
        for ``accuracy``. This call matches the older API.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy = tm.functional.accuracy(logits, y)
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log("loss", loss, on_epoch=True, prog_bar=True)
        self.log("accuracy", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log("val_accuracy", accuracy, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a cosine-annealing schedule of period ``scheduler_t_max``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.scheduler_t_max)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

In [None]:
# Uses multi-head attention (MHA) within a more traditional classifier; referred to as MHA-PCT in the paper.
class PCT_Classifier_2(pl.LightningModule):
    """Point-cloud classifier with a selectable self- or multi-head attention block.

    ``forward`` runs these steps in order:
    1. embedding;
    2. attention block, called with the coordinates when ``positional_embedding``
       is truthy;
    3. ``MA_Pooling``;
    4. MLP classification head.

    Args:
        num_classes: number of output logits.
        learning_rate: AdamW learning rate.
        input_features: per-point input features given to the embedding.
        attention: "self" (``Self_Attention_Block``) or "multihead"
            (``MultiHead_Attention_Block``).
        encoder_channels: channel width produced by the embedding.
        attention_layers: number of attention layers.
        heads, forward_expansion, attention_dropout: forwarded to
            ``MultiHead_Attention_Block``.
        key_size, value_size: forwarded to ``Self_Attention_Block``.
        pooling: "max", "avg" or "both"; forwarded to ``MA_Pooling``.
        classification_layer_size: width of the last hidden layer of the head.
        dropout: dropout probability in the classification head.
        embedding: "naive" or "neighboorhood".
        k, sampling: forwarded to ``Neighboorhood_Embedding``.
        positional_embedding: forwarded to ``Neighboorhood_Embedding``. Also
            decides whether ``forward`` passes coordinates to the attention block.

    Raises:
        ValueError: for an unknown ``embedding``, ``attention`` or ``pooling`` value.

    TODO (ideas): a skip connection before the head; a larger hidden layer.
    """

    def __init__(self, num_classes, learning_rate=0.01, input_features=3, attention="self", encoder_channels=256,
                 attention_layers=4, heads=8, forward_expansion=4, attention_dropout=0.1,
                 key_size=0.25, value_size=1, pooling="both", classification_layer_size=256, dropout=0.5,
                 embedding="neighboorhood", k=32, sampling=0.25, positional_embedding=False):
        super().__init__()
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.input_features = input_features
        self.attention = attention
        self.encoder_channels = encoder_channels
        self.attention_layers = attention_layers
        self.heads = heads
        self.forward_expansion = forward_expansion
        self.attention_dropout = attention_dropout
        self.key_size = key_size
        self.value_size = value_size
        self.classification_layer_size = classification_layer_size
        self.dropout = dropout
        self.k = k
        self.sampling = sampling
        self.positional_embedding = positional_embedding
        # Keep the config strings. `self.embedding` / `self.pooling` hold the modules.
        self.embedding_type = embedding
        self.pooling_type = pooling

        if embedding == "naive":
            self.embedding = Naive_Embedding(self.input_features, self.encoder_channels)
        elif embedding == "neighboorhood":
            self.embedding = Neighboorhood_Embedding(self.input_features, self.encoder_channels,
                                                     self.k, self.sampling, self.positional_embedding)
        else:
            raise ValueError("embedding must be either naive or neighboorhood")

        if self.attention == "self":
            self.attention_block = Self_Attention_Block(self.attention_layers, self.encoder_channels,
                                                        self.key_size, self.value_size)
        elif self.attention == "multihead":
            self.attention_block = MultiHead_Attention_Block(self.attention_layers, self.heads, self.encoder_channels,
                                                             self.forward_expansion, self.attention_dropout)
        else:
            raise ValueError("attention must be either self or multihead")

        # Previously an invalid value left `self.pooling` as a str, failing in forward().
        if pooling not in {"both", "max", "avg"}:
            raise ValueError("pooling must be either max, avg or both")
        self.pooling = MA_Pooling(pooling)

        # Assumes the attention block outputs attention_layers * encoder_channels
        # channels. TODO: confirm in the attention blocks. "both" pooling doubles the width.
        self.classification_input = self.encoder_channels * self.attention_layers * (2 if pooling == "both" else 1)

        self.classification_layers = nn.Sequential(
            nn.Linear(self.classification_input, self.classification_layer_size * 2, bias=False),
            nn.BatchNorm1d(self.classification_layer_size * 2),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.classification_layer_size * 2, self.classification_layer_size, bias=False),
            nn.BatchNorm1d(self.classification_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.classification_layer_size, self.num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of point clouds."""
        x, coordinates = self.embedding(x)
        if self.positional_embedding:
            x = self.attention_block(x, coordinates)
        else:
            x = self.attention_block(x)
        x = self.pooling(x)
        return self.classification_layers(x)

    def _shared_step(self, batch):
        """Run the model on one ``(x, y)`` batch and return ``(loss, accuracy)``.

        Label smoothing was tried here and did not work well.
        NOTE(review): torchmetrics >= 0.11 requires ``task=`` / ``num_classes=``
        for ``accuracy``. This call matches the older API.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy = tm.functional.accuracy(logits, y)
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log("loss", loss, on_epoch=True, prog_bar=True)
        self.log("accuracy", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log("val_accuracy", accuracy, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau (patience 5) driven by ``val_loss``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

In [None]:
# Transformer with KNN pooling and a structure fairly different from PCT; referred to as PCT_MHA 2.
class PCT_Classifier_4(pl.LightningModule):
    """Multi-head-attention classifier with a global max feature plus KNN pooling.

    ``forward`` runs these steps in order:
    1. embedding;
    2. ``MultiHead_Attention_Block3``;
    3. two 1x1 ``Conv1d`` layers;
    4. computes a global max over dim 2 and ``KNN_Pooling`` over the coordinates;
    5. concatenates both features;
    6. MLP classification head.

    Args:
        num_classes: number of output logits.
        learning_rate: AdamW learning rate.
        input_features: per-point input features given to the embedding.
        encoder_channels: channel width produced by the embedding.
        attention_layers, heads, attention_dropout: forwarded to
            ``MultiHead_Attention_Block3``.
        classification_layer_size: width of the last hidden layer of the head.
        dropout: dropout probability in the classification head.
        embedding: "naive" or "neighboorhood".
        k, sampling: forwarded to ``Neighboorhood_Embedding``.
        positional_embedding: currently unused. The embedding is always built
            with ``False`` and ``forward`` never passes coordinates to the
            attention block.
        pool_points: first argument of ``KNN_Pooling``. It also sizes the head
            input as ``(pool_points + 1) * encoder_channels``.

    Raises:
        ValueError: if ``embedding`` is not "naive" or "neighboorhood".

    TODO (ideas): a skip connection before the head; a larger hidden layer.
    """

    def __init__(self, num_classes, learning_rate=0.01, input_features=3, encoder_channels=256,
                 attention_layers=4, heads=8, attention_dropout=0.1, classification_layer_size=256, dropout=0.5,
                 embedding="neighboorhood", k=32, sampling=0.25, positional_embedding=False, pool_points=8):
        super().__init__()
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.input_features = input_features
        self.encoder_channels = encoder_channels
        self.attention_layers = attention_layers
        self.heads = heads
        self.attention_dropout = attention_dropout
        self.classification_layer_size = classification_layer_size
        self.dropout = dropout
        self.k = k
        self.sampling = sampling
        self.positional_embedding = positional_embedding  # recorded only; see class docstring
        self.pool_points = pool_points
        # Keep the config string. `self.embedding` holds the module.
        self.embedding_type = embedding

        if embedding == "naive":
            self.embedding = Naive_Embedding(self.input_features, self.encoder_channels)
        elif embedding == "neighboorhood":
            self.embedding = Neighboorhood_Embedding(self.input_features, self.encoder_channels,
                                                     self.k, self.sampling, False)
        else:
            # Previously an invalid value left `self.embedding` as a str, failing in forward().
            raise ValueError("embedding must be either naive or neighboorhood")

        self.attention_block = MultiHead_Attention_Block3(self.attention_layers, self.heads,
                                                          self.encoder_channels, self.attention_dropout)

        # Point-wise MLP (expand to 2x width, project back) over the attention output.
        self.conv_fuse = nn.Sequential(
            nn.Conv1d(self.encoder_channels, self.encoder_channels * 2, kernel_size=1, bias=False),
            nn.BatchNorm1d(self.encoder_channels * 2),
            nn.ReLU(),
            nn.Conv1d(self.encoder_channels * 2, self.encoder_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(self.encoder_channels),
            nn.ReLU()
        )
        # NOTE(review): the second argument (32) is hard-coded and independent of
        # `k`. It is presumably a neighbourhood size. TODO: confirm in pooling.KNN_Pooling.
        self.pooling = KNN_Pooling(pool_points, 32)

        # Global max feature (encoder_channels) plus pooled features. Assumes
        # KNN_Pooling returns pool_points * encoder_channels flattened values.
        # TODO: confirm.
        self.classification_input = (pool_points + 1) * self.encoder_channels

        self.classification_layers = nn.Sequential(
            nn.Linear(self.classification_input, self.classification_layer_size * 2, bias=False),
            nn.BatchNorm1d(self.classification_layer_size * 2),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.classification_layer_size * 2, self.classification_layer_size, bias=False),
            nn.BatchNorm1d(self.classification_layer_size),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(self.classification_layer_size, self.num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of point clouds."""
        x, coordinates = self.embedding(x)
        x = self.attention_block(x)
        x = self.conv_fuse(x)
        # Max over dim 2 (the point axis, given Conv1d's (batch, channels, length) layout).
        global_feature = torch.max(x, 2)[0]
        x = self.pooling(x, coordinates)
        x = torch.cat((global_feature, x), 1)
        return self.classification_layers(x)

    def _shared_step(self, batch):
        """Run the model on one ``(x, y)`` batch and return ``(loss, accuracy)``.

        Label smoothing was tried here and did not work well.
        NOTE(review): torchmetrics >= 0.11 requires ``task=`` / ``num_classes=``
        for ``accuracy``. This call matches the older API.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy = tm.functional.accuracy(logits, y)
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log("loss", loss, on_epoch=True, prog_bar=True)
        self.log("accuracy", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log("val_accuracy", accuracy, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a cosine-annealing schedule of period 250."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 250)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

In [None]:
SEED = 42
MAX_EPOCHS = 250
# Must exceed the ReduceLROnPlateau patience (5) used by PCT_Classifier_2.
# With equal patience, EarlyStopping stops on the 5th epoch without improvement,
# at or before the epoch on which the plateau scheduler would first lower the LR.
EARLY_STOPPING_PATIENCE = 15

pl.seed_everything(SEED, workers=True)
# NOTE(review): benchmark=True lets cuDNN auto-select kernels, so runs may not be
# bit-for-bit reproducible even with a fixed seed.
logger = TensorBoardLogger("tb_logs", name="my_model")

# Keep the single best checkpoint by validation accuracy and, separately, by validation loss.
checkpoint_callback_accuracy = pl.callbacks.ModelCheckpoint(monitor="val_accuracy", mode="max", save_top_k=1)
checkpoint_callback_loss = pl.callbacks.ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early_stopping_callback = pl.callbacks.EarlyStopping(monitor="val_loss", patience=EARLY_STOPPING_PATIENCE)

trainer = pl.Trainer(overfit_batches=0, gpus=-1, benchmark=True, max_epochs=MAX_EPOCHS,
                     callbacks=[checkpoint_callback_accuracy, checkpoint_callback_loss, early_stopping_callback],
                     logger=logger)

In [None]:
# Example: train the original PCT on ModelNet40.
train_dataloader, val_dataloader, test_dataloader = create_datasets(ModelNet40, 32)
model = PCT_Classifier(40, dropout=0.5,
                       pooling="max",
                       learning_rate=0.0001,
                       key_size=0.25,
                       k=32,
                       sampling=0.25,
                       encoder_channels=256)
# Without fitting, no checkpoint is written and the validation cell has nothing to load.
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
# Validate the checkpoint with the best validation accuracy.
best_ckpt_path = checkpoint_callback_accuracy.best_model_path
# best_model_path stays empty until a checkpoint has been saved by trainer.fit(...).
assert best_ckpt_path, "No checkpoint saved yet - run trainer.fit(...) first"
trainer.validate(model, dataloaders=val_dataloader, ckpt_path=best_ckpt_path)