In [60]:
from src.data.dataloader import VesselCaptureDataset
from tqdm import tqdm
from torch.utils.data import DataLoader

In [61]:
data_dir = "Captured_images/"
dataset = VesselCaptureDataset(data_dir)

In [62]:
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [63]:
for batch in dataloader:
    color_images = batch["color_image"]
    depth_images = batch["depth_image"]
    vol_liquids = batch["vol_liquid"]
    vessel_names = batch["vessel_name"]
    vol_vessels = batch["vol_vessel"]

    # Do something with the data
    print(vessel_names)
    print(vol_liquids)

['seace', 'seace']
tensor([768, 123])
['seace', 'seace']
tensor([443, 345])
['trz1', 'seace']
tensor([100,  53])
['seace', 'seace']
tensor([234,  12])


In [64]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F


# Define the neural network architecture
class VesselNet(nn.Module):
    def __init__(self):
        super(VesselNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=5, padding=2
        )
        self.conv3 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=5, padding=2
        )

        self.fc1 = nn.Linear(in_features=128 * 4800, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=2)

    def forward(self, depth_image):
        x = depth_image.unsqueeze(1)  # add channel dimension
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, kernel_size=2)
        # print(x.shape)
        x = x.view(-1, 128 * 4800)
        # print(x.shape)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # print(x.shape)

        return x

In [65]:
# Define the training loop
def train(model, criterion, optimizer, train_loader):
    model.train()

    # Wrap train_loader with tqdm for a progress bar
    progress_bar = tqdm(train_loader, desc="Training")

    for i, data in enumerate(progress_bar):
        inputs = data["depth_image"]
        targets = torch.stack([data["vol_liquid"], data["vol_vessel"]], dim=1)
        targets = targets.float()

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Update progress bar
        progress_bar.set_postfix({"loss": loss.item()})

In [66]:
from sklearn.model_selection import train_test_split

# Load the dataset
dataset = VesselCaptureDataset(data_dir)

# Split the dataset into training and test data
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Set up the data loader and training parameters for the training data
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
train_size = len(train_data)

# Set up the data loader and training parameters for the test data
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
test_size = len(test_data)

In [67]:
model = VesselNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # print(f'Epoch {epoch + 1}/{num_epochs}')
    train(model, criterion, optimizer, train_loader)

# Save the trained model
torch.save(model.state_dict(), "vessel_net.pth")

Training: 100%|██████████| 3/3 [00:13<00:00,  4.38s/it, loss=8.03e+6]
Training: 100%|██████████| 3/3 [00:11<00:00,  3.85s/it, loss=3.28e+4]
Training: 100%|██████████| 3/3 [00:11<00:00,  3.86s/it, loss=1.42e+4]
Training: 100%|██████████| 3/3 [00:11<00:00,  3.95s/it, loss=8.03e+3]
Training: 100%|██████████| 3/3 [00:12<00:00,  4.05s/it, loss=5.16e+3]
Training: 100%|██████████| 3/3 [00:11<00:00,  3.91s/it, loss=2.19e+3]
Training: 100%|██████████| 3/3 [00:12<00:00,  4.10s/it, loss=2.05e+4]
Training: 100%|██████████| 3/3 [00:12<00:00,  4.09s/it, loss=678]    
Training: 100%|██████████| 3/3 [00:12<00:00,  4.17s/it, loss=3.42e+3]
Training: 100%|██████████| 3/3 [00:12<00:00,  4.30s/it, loss=1.05e+4]
