In [1]:
import os
import time
import numpy as np
from datetime import timedelta
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.optim import Adam

In [2]:

# Pick the fastest available backend: CUDA GPU first, then Apple-silicon MPS,
# and finally the CPU. Everything downstream moves tensors to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: mps


In [3]:
class DefectDataset(Dataset):
    """Segmentation dataset pairing each image with one combined defect mask.

    Layout read from ``<root_dir>/<split>``: an ``Img.after.melting`` folder
    with the input images, plus one ``Defect_class<N>`` folder per defect type
    holding masks that share the image's file name. A missing mask file is
    treated as "no defect of that class in this image".

    Each sample is ``(image, mask)``. ``mask`` is a ``torch.long`` tensor of
    class indices: 0 is background and 1..6 stand for defect classes
    0, 5, 8, 9, 10, 11 in that order. Where class masks overlap, the class
    listed later overwrites the earlier one.

    Args:
        root_dir: Directory that contains the split sub-folders.
        transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the combined mask (a PIL
            image whose pixel values are class indices). Defaults to a
            nearest-neighbour resize to 512x512 followed by ``PILToTensor``,
            which keeps the indices intact.
        split: Name of the sub-folder to read, e.g. ``'train'`` or ``'val'``.
    """

    def __init__(self, root_dir, transform=None, mask_transform=None, split='train'):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        # Nearest-neighbour resizing never invents in-between class ids, and
        # PILToTensor (unlike ToTensor) does not rescale values to [0, 1].
        self.mask_transform = mask_transform or transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor()
        ])
        self.image_dir = os.path.join(self.root_dir, 'Img.after.melting')
        self.class_dirs = [f'Defect_class{cls}' for cls in [0, 5, 8, 9, 10, 11]]
        # Sorted so that sample order is stable across runs and file systems.
        self.image_names = sorted(os.listdir(self.image_dir))

    def __len__(self):
        """Return the number of images found in the split's image folder."""
        return len(self.image_names)

    @staticmethod
    def _to_class_indices(mask):
        """Convert the output of ``mask_transform`` into a long tensor of class ids.

        ``transforms.ToTensor`` divides uint8 pixels by 255, so class ids 1..6
        arrive as tiny fractions that a plain ``.long()`` would truncate to 0,
        erasing every defect label. Fractional float masks are therefore scaled
        back by 255 and rounded; integer-valued or integer-typed masks are used
        as they are.
        """
        if not torch.is_tensor(mask):
            # e.g. a transform that only resizes and returns a PIL image.
            mask = torch.from_numpy(np.array(mask))
        if mask.is_floating_point():
            if not torch.equal(mask, mask.round()):
                mask = mask * 255
            mask = mask.round()
        # Drop the channel axis added by ToTensor / PILToTensor (size-1 only).
        return mask.squeeze(0).long()

    def __getitem__(self, idx):
        """Load image ``idx`` and build its combined class-index mask."""
        img_name = self.image_names[idx]
        with Image.open(os.path.join(self.image_dir, img_name)) as img_file:
            image = img_file.convert('RGB')

        # PIL reports (width, height); numpy arrays are (rows, cols).
        width, height = image.size
        mask = np.zeros((height, width), dtype=np.uint8)

        # Paint each class's foreground pixels with its index (1-based, 0 = background).
        for class_index, class_dir in enumerate(self.class_dirs, start=1):
            mask_path = os.path.join(self.root_dir, class_dir, img_name)
            if os.path.exists(mask_path):
                with Image.open(mask_path) as mask_file:
                    class_mask = np.array(mask_file)
                mask[class_mask > 0] = class_index

        if self.transform:
            image = self.transform(image)
        mask = self._to_class_indices(self.mask_transform(Image.fromarray(mask)))

        return image, mask

In [4]:
def create_model(num_classes):
    """Build a pretrained DeepLabV3-ResNet50 with heads resized for ``num_classes``.

    The final 1x1 convolution of both the main classifier and the auxiliary
    classifier is replaced so each predicts ``num_classes`` channels; the new
    layers start from random initialisation while the rest keeps its
    pretrained weights.

    Args:
        num_classes: Number of output classes, background included.

    Returns:
        The model, already moved to the notebook-level ``device``.
    """
    # Local import keeps this cell self-contained. The explicit weights enum
    # replaces the deprecated ``pretrained=True`` flag and selects the same
    # COCO (VOC label subset) checkpoint.
    from torchvision.models.segmentation import DeepLabV3_ResNet50_Weights

    model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)
    for head in (model.classifier, model.aux_classifier):
        # Read the input width from the layer being replaced instead of hard-coding it.
        in_channels = head[4].in_channels
        head[4] = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
    return model.to(device)


In [5]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Every epoch runs one pass over ``train_loader`` with gradient updates and
    one pass over ``val_loader`` without gradients, then prints the
    sample-weighted mean losses and timing. Whenever the validation loss
    improves, the state dict is written to ``best_deeplabv3_defect.pth`` in
    the working directory.

    Args:
        model: Network whose forward pass returns a dict with an ``'out'`` entry.
        train_loader: Sized iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        criterion: Callable ``(outputs, masks) -> scalar loss tensor``.
        optimizer: Optimiser bound to ``model``'s parameters.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object with the weights of the last epoch (not
        necessarily the best ones, which are on disk).
    """
    # Send batches wherever the model lives instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    best_val_loss = float('inf')
    start_time = time.time()
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        train_loss = 0.0
        train_seen = 0

        for batch_idx, (images, masks) in enumerate(train_loader):
            images = images.to(target_device)
            masks = masks.to(target_device)

            # Forward pass
            outputs = model(images)['out']
            loss = criterion(outputs, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            train_loss += loss.item() * images.size(0)
            train_seen += images.size(0)

            # Print progress
            if batch_idx % 10 == 0:
                # Mean wall time per batch so far, data loading included, is a
                # steadier basis for the ETA than the duration of one batch.
                avg_batch_time = (time.time() - epoch_start) / (batch_idx + 1)
                remaining = (num_batches - batch_idx - 1) * avg_batch_time
                print(f'\rEpoch {epoch+1}/{num_epochs} | Batch {batch_idx}/{len(train_loader)} | '
                      f'ETA: {timedelta(seconds=int(remaining))}', end='')

        # Validation
        model.eval()
        val_loss = 0.0
        val_seen = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(target_device)
                masks = masks.to(target_device)
                outputs = model(images)['out']
                val_loss += criterion(outputs, masks).item() * images.size(0)
                val_seen += images.size(0)

        # Calculate metrics over the samples actually processed; this stays
        # correct with drop_last or custom samplers. An empty loader yields NaN
        # rather than a ZeroDivisionError (and NaN never counts as "best").
        train_loss = train_loss / train_seen if train_seen else float('nan')
        val_loss = val_loss / val_seen if val_seen else float('nan')
        epoch_time = time.time() - epoch_start
        total_remaining = (num_epochs - epoch - 1) * epoch_time

        print(f'\rEpoch {epoch+1}/{num_epochs} | '
              f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | '
              f'Time: {timedelta(seconds=int(epoch_time))} | '
              f'Total ETA: {timedelta(seconds=int(total_remaining))}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_deeplabv3_defect.pth')

    print(f'\nTraining completed in {timedelta(seconds=int(time.time()-start_time))}')
    return model

In [6]:
# Training hyper-parameters
batch_size = 4
learning_rate = 0.001
num_epochs = 20
num_classes = 7  # 6 defect classes + background

# Image pipeline: resize, convert to a float tensor in [0, 1], then normalise
# with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((384, 384)),  # Reduced from 512x512
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Mask pipeline: pixel values are class indices, not intensities, so they must
# survive untouched. Nearest-neighbour resizing avoids blending neighbouring
# class ids, and PILToTensor keeps the raw uint8 values (ToTensor would divide
# by 255, after which a cast to long turns every label into 0).
mask_transform = transforms.Compose([
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor()
])

In [7]:
# Single source of truth for the dataset location. Set the DEFECT_DATA_ROOT
# environment variable to run the notebook on another machine without editing it.
DATA_ROOT = os.environ.get(
    'DEFECT_DATA_ROOT',
    '/Users/sanjanahaldar/Library/CloudStorage/GoogleDrive-sanukadam721@gmail.com/My Drive/Info_Project/Defect_Detection/DataSets/Data.Splitting/After_Melting_Defect_Detection',
)

train_dataset = DefectDataset(root_dir=DATA_ROOT,
                              transform=transform,
                              mask_transform=mask_transform,
                              split='train')
val_dataset = DefectDataset(root_dir=DATA_ROOT,
                            transform=transform,
                            mask_transform=mask_transform,
                            split='val')


In [8]:
# drop_last=True on the training loader: a trailing batch with a single sample
# makes the BatchNorm layer in DeepLabV3's global-pooling branch raise in train
# mode ("Expected more than 1 value per channel"). At most batch_size - 1
# samples are skipped per epoch, and shuffling changes which ones each time.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation runs in eval mode, so every sample is kept and order is fixed.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [9]:
# Per-pixel multi-class loss over the segmentation logits.
criterion = nn.CrossEntropyLoss()

# Network with one output channel per class, and an Adam optimiser over all of
# its parameters.
model = create_model(num_classes=num_classes)
optimizer = Adam(params=model.parameters(), lr=learning_rate)



In [None]:
# Launch training with the objects configured in the cells above.
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

Epoch 1/20 | Batch 0/626 | ETA: 3:30:50