<a href="https://colab.research.google.com/github/Sivasankari1985/TLNet/blob/master/ObjpointNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
import torch
import os
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader



In [2]:
class SimplePointNet(nn.Module):
    def __init__(self, num_classes):
        super(SimplePointNet, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, 1)  # Input channels: 3, Output channels: 64
        self.conv2 = nn.Conv1d(64, 128, 1)  # Input channels: 64, Output channels: 128
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.fc1(x.mean(dim=-1)))  # Global max pooling
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)




In [3]:
# Define your dataset class
class MyDataset(Dataset):
    def __init__(self, data_paths, labels):
        self.data_paths = data_paths
        self.labels = labels

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, index):
          V = []
          directory = os.listdir(self.data_paths)
          for filename in directory:
               if filename.endswith('.obj'):
                with open(os.path.join(self.data_paths,filename), 'r') as fid:
                  lines = fid.readlines()
               for line in lines:
                  if line.startswith('v '):
                      vertex = list(map(float, line.strip().split()[1:4]))
                      V.append(vertex)

          return V



In [4]:
# Set up paths and labels for training and validation datasets
drive.mount('/content/gdrive/')
directory = '/content/gdrive/MyDrive/3DPotteryDataset_v_1/3D Models/All Models/'
# List all .obj files in the directory
obj_files = os.listdir(directory)
# Visualize each .obj file using SimplePointNet

#obj_file_path = os.path.join(directory, filename)
train_data_paths = directory  # List of file paths for training data
train_labels = ['Abstract','Aryballos','Bowl']

# List of labels for training data
val_data_paths = directory  # List of file paths for validation data
val_labels = ['Abstract','Aryballos','Bowl'] # List of labels for validation data



Mounted at /content/gdrive/


In [5]:
# Create DataLoader for training and validation datasets
train_dataset = MyDataset(train_data_paths, train_labels)
print(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
print('train dataset')
print(train_dataset)
val_dataset = MyDataset(val_data_paths, val_labels)
val_loader = DataLoader(val_dataset, batch_size=5, shuffle=False)



<__main__.MyDataset object at 0x7a0b566d06a0>
train dataset
<__main__.MyDataset object at 0x7a0b566d06a0>


In [6]:
# Instantiate SimplePointNet model
num_classes = 5  # Adjust as needed
model = SimplePointNet(num_classes)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()



In [None]:
# Training loop
num_epochs = 10
train_losses = []
val_accuracies = []

for epoch in range(10):
    model.train()
    total_loss = 0.0
    print(train_loader)
    for data, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    train_loss = total_loss / len(train_loader)
    train_losses.append(train_loss)

    # Validation
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, labels in val_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            val_accuracy = correct / total
            val_accuracies.append(val_accuracy)

            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Training Loss: {train_loss:.4f}, '
                  f'Validation Accuracy: {val_accuracy:.4f}')



<torch.utils.data.dataloader.DataLoader object at 0x7a0b566d0e80>


In [None]:
  # Visualize the validation accuracy graph
            plt.plot(val_accuracies)
            plt.title('Validation Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.show()