In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
import os
import numpy as np
from PIL import Image

In [None]:
# Choose the compute device: Apple MPS is preferred, then CUDA, otherwise CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)

In [None]:

class SegmentationDataset(Dataset):
    """Image/mask pairs stored side by side in a single directory.

    Every ``<name>.jpg`` file in ``data_dir`` is an input image and is paired
    with ``<name>_mask.png`` from the same directory.
    """

    def __init__(self, data_dir):
        """Index the ``.jpg`` files in ``data_dir`` and build the transforms.

        Args:
            data_dir: Directory that holds the images and their mask files.
        """
        self.data_dir = data_dir
        # Images: resize, convert to a float tensor, then normalize per channel.
        self.img_transform = transforms.Compose([
            transforms.Resize((512, 1024)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Masks: nearest-neighbour resize so label values are never blended.
        self.mask_transform = transforms.Compose([
            transforms.Resize(
                (512, 1024),
                interpolation=transforms.InterpolationMode.NEAREST,
            )
        ])
        # os.listdir returns entries in arbitrary order; sort so that index ->
        # sample is reproducible across runs and machines.
        self.images = sorted(f for f in os.listdir(data_dir) if f.endswith('.jpg'))
        print(f"Found {len(self.images)} images in {data_dir}")

    @staticmethod
    def _mask_name(img_name):
        """Return the mask file name for ``img_name`` (``x.jpg`` -> ``x_mask.png``)."""
        # Swap only the final extension; str.replace would also rewrite any
        # ".jpg" that happens to appear earlier in the file name.
        stem, _ = os.path.splitext(img_name)
        return stem + "_mask.png"

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is the transformed RGB
            image tensor and ``mask`` is the resized mask as an int64 tensor.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.data_dir, img_name)
        mask_path = os.path.join(self.data_dir, self._mask_name(img_name))

        # Context managers close the underlying files as soon as the pixel
        # data has been read.
        with Image.open(img_path) as img_file:
            image = self.img_transform(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            # NOTE(review): the mask is used as-is (no mode conversion); this
            # assumes its pixel values are already class indices - confirm.
            mask = np.array(self.mask_transform(mask_file))

        mask = torch.from_numpy(mask)

        return image, mask.squeeze().long()

In [None]:
# Training data: shuffled batches of 4; a trailing incomplete batch is dropped.
train_dataset = SegmentationDataset("../teethSegSet/trainSet/")
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

# Test data: one sample per batch, iterated in dataset order.
test_dataset = SegmentationDataset("../teethSegSet/testSet/")
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)

In [None]:
import matplotlib.pyplot as plt

def visulize_segmentation( loader ):
    """Plot the first sample of the first batch of ``loader`` beside its mask.

    The image is de-normalized with the same per-channel mean/std constants
    that ``SegmentationDataset`` passes to ``Normalize`` so it displays with
    natural colours.

    Args:
        loader: Iterable yielding ``(image, mask)`` batches. Any batch size is
            accepted; only the first sample of the first batch is shown.
    """
    img, mask = next(iter(loader))

    # Keep a single sample. Relying on squeeze() only works for batch size 1:
    # a larger batch leaves a 4-D array that the 3-axis transpose rejects.
    if img.ndim == 4:
        img = img[0]
    if mask.ndim > 2:
        mask = mask[0]

    img_np = img.cpu().numpy()
    mask_np = mask.cpu().numpy().squeeze()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Undo the normalization (CHW layout), then move channels last for imshow.
    img_np = std[:, None, None] * img_np + mean[:, None, None]
    img_np = np.transpose(img_np, (1, 2, 0))
    # Values can land slightly outside [0, 1]; clip so imshow gets a valid
    # float RGB image.
    img_np = np.clip(img_np, 0.0, 1.0)

    print(f"Image shape: {img_np.shape}, Image dtype: {img_np.dtype}") 
    print(f"Mask shape: {mask_np.shape}, Mask dtype: {mask_np.dtype}")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    ax1.imshow(img_np)
    ax1.set_title("Image")  
    ax1.axis('off')

    ax2.imshow(mask_np*255)
    ax2.set_title("Mask")
    ax2.axis('off')

    plt.tight_layout()
    plt.show()

# Sanity check of the data pipeline: show the first test image beside its mask.
visulize_segmentation(test_loader)

In [None]:
# Binary segmentation: the network must emit two class logits per pixel.
num_classes = 2
model = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)

# Replace the last layer of the main head with a 1x1 convolution that maps the
# existing feature channels to `num_classes` outputs.
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, kernel_size=1, stride=1)

# Same replacement for the auxiliary head.
aux_in_channels = model.aux_classifier[-1].in_channels
model.aux_classifier[-1] = nn.Conv2d(aux_in_channels, num_classes, kernel_size=1, stride=1)

model = model.to(device)



In [None]:
# Adam updates every parameter of the model; the loss is cross-entropy computed
# between the predicted class logits and the integer mask.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

In [None]:
num_epochs = 10000

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # ---- training pass ----
    running_loss = 0.0
    train_samples = 0
    model.train()
    for inputs, masks in train_loader:
        inputs, masks = inputs.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)['out']
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        n_batch = inputs.size(0)
        running_loss += loss.item() * n_batch
        train_samples += n_batch

    # Average over the samples actually processed. The training loader uses
    # drop_last=True, so len(train_loader.dataset) can exceed that count and
    # would under-report the loss; max(..., 1) guards against an empty loader.
    epoch_loss = running_loss / max(train_samples, 1)
    print(f"Training Loss: {epoch_loss:.4f}")

    # ---- evaluation pass on the test loader (no gradients) ----
    model.eval()
    eval_loss = 0.0
    eval_samples = 0
    with torch.no_grad():
        for inputs, masks in test_loader:
            inputs, masks = inputs.to(device), masks.to(device)
            outputs = model(inputs)['out']
            loss = loss_fn(outputs, masks)
            n_batch = inputs.size(0)
            eval_loss += loss.item() * n_batch
            eval_samples += n_batch
    print(f"Validation Loss: {eval_loss / max(eval_samples, 1):.4f}")

    # Overwrite the checkpoint after every epoch so an interrupted run still
    # leaves the most recent weights on disk.
    torch.save(model.state_dict(), 'final_model.pth')

In [None]:


# Load saved weights from disk (mapped onto the active device) into the model.
# NOTE(review): the training cell writes 'final_model.pth' but this cell reads
# 'teethSegModel.pth' - confirm which checkpoint file is intended.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch supports it).
state_dict = torch.load('teethSegModel.pth', map_location=device)
model.load_state_dict(state_dict)



In [None]:
def visulize_segmentation( model, test_loader, device ) :
    """Plot the first sample of ``test_loader`` beside the model's prediction.

    NOTE: this redefines the earlier one-argument ``visulize_segmentation``
    from the data-inspection cell; once this cell has run, the earlier version
    is no longer reachable under that name.

    Args:
        model: Segmentation model whose forward pass returns a dict with an
            ``'out'`` entry holding per-class logits. It is switched to eval
            mode and left that way.
        test_loader: Iterable yielding ``(image, mask)`` batches; only the
            first sample of the first batch is used and the mask is ignored.
        device: Device on which inference is run.
    """
    model.eval()
    img, _ = next(iter(test_loader))
    # Accept a single un-batched image as well as a batch.
    if img.ndim == 3:
        img = img.unsqueeze(0)
    img = img.to(device)

    with torch.no_grad():
        out = model(img)['out']
        # The class with the highest logit along dim 1 is the predicted label.
        pred = out.argmax(dim=1)

    # Keep the first sample only. Relying on squeeze() breaks for batch sizes
    # above 1, where the 3-axis transpose below receives a 4-D array.
    pred = pred[0].cpu().numpy()
    img_np = img[0].cpu().numpy()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization (CHW layout), then move channels last for imshow.
    img_np = std[:, None, None] * img_np + mean[:, None, None]
    img_np = np.transpose(img_np, (1, 2, 0))
    # Clip to [0, 1] so imshow receives a valid float RGB image.
    img_np = np.clip(img_np, 0.0, 1.0)

    print(f"Image shape: {img_np.shape}, Image dtype: {img_np.dtype}")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    ax1.imshow(img_np)
    ax1.set_title("Image")
    ax1.axis('off')

    ax2.imshow(pred*255)
    ax2.set_title("Prediction")
    ax2.axis('off')

    plt.tight_layout()
    plt.show()

# Show the loaded model's prediction for the first sample of the test loader.
visulize_segmentation(model, test_loader, device)