In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net for per-pixel prediction.

    Encoder: five ``double_conv`` stages (64, 128, 256, 512, 1024 channels)
    separated by 2x2 max-pooling. Decoder: four 2x transposed-conv upsamplings,
    each concatenated (along channels) with the encoder output of matching
    resolution -- the skip connection -- and refined by a ``double_conv``.
    A final 1x1 convolution maps to ``out_channels`` with no activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.enc1 = self.double_conv(in_channels, 64)
        self.enc2 = self.double_conv(64, 128)
        self.enc3 = self.double_conv(128, 256)
        self.enc4 = self.double_conv(256, 512)
        self.enc5 = self.double_conv(512, 1024)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: each dec* receives upsampled channels + skip channels,
        # hence its input width is twice the upconv output width.
        self.up4 = self.upconv(1024, 512)
        self.dec4 = self.double_conv(1024, 512)

        self.up3 = self.upconv(512, 256)
        self.dec3 = self.double_conv(512, 256)

        self.up2 = self.upconv(256, 128)
        self.dec2 = self.double_conv(256, 128)

        self.up1 = self.upconv(128, 64)
        self.dec1 = self.double_conv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        """2x2 transposed convolution with stride 2: doubles H and W."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _pad_and_concat(upsampled, skip):
        """Concatenate ``upsampled`` and ``skip`` along dim 1 after size matching.

        MaxPool2d floors odd sizes, so after a 2x upsample the decoder map can
        be one pixel smaller than the encoder map it is joined with. Zero-pad
        ``upsampled`` (split as evenly as possible between both sides) up to
        ``skip``'s spatial size so ``torch.cat`` does not fail.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            # F.pad pads the last dim first: (left, right, top, bottom).
            upsampled = F.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                          diff_h // 2, diff_h - diff_h // 2])
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network.

        ``x`` is indexed as (batch, channels, H, W). H and W need not be
        multiples of 16, but must each be at least 16 so the four poolings
        leave a non-empty map.

        Returns:
            Tensor of shape (batch, out_channels, H, W) with raw (un-activated)
            scores from ``final_conv``.
        """
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        x = self.dec4(self._pad_and_concat(self.up4(enc5), enc4))
        x = self.dec3(self._pad_and_concat(self.up3(x), enc3))
        x = self.dec2(self._pad_and_concat(self.up2(x), enc2))
        x = self.dec1(self._pad_and_concat(self.up1(x), enc1))

        return self.final_conv(x)



In [None]:
!pip install torchvision

In [4]:
import zipfile
import os
from PIL import Image
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Unpack an entire ZIP archive into a target folder
def extract_zip(zip_path, extract_to):
    """Extract every member of the ZIP archive at ``zip_path`` into ``extract_to``.

    The whole archive is unpacked, not a single sub-folder; callers choose the
    sub-directories they need under ``extract_to`` afterwards. A confirmation
    line is printed once extraction has finished.
    """
    archive = zipfile.ZipFile(zip_path, 'r')
    with archive:
        archive.extractall(extract_to)
        print(f"Extracted ZIP file to {extract_to}")

# Define the UNet model. NOTE: this redefinition shadows the UNet from the first
# cell; unlike that one it defaults to in_channels=3 (RGB input).
class UNet(nn.Module):
    """U-Net for per-pixel prediction.

    Encoder: five ``double_conv`` stages (64, 128, 256, 512, 1024 channels)
    separated by 2x2 max-pooling. Decoder: four 2x transposed-conv upsamplings,
    each concatenated (along channels) with the encoder output of matching
    resolution -- the skip connection -- and refined by a ``double_conv``.
    A final 1x1 convolution maps to ``out_channels`` with no activation.

    Args:
        in_channels: channels of the input tensor (default 3).
        out_channels: channels of the output tensor.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.enc1 = self.double_conv(in_channels, 64)
        self.enc2 = self.double_conv(64, 128)
        self.enc3 = self.double_conv(128, 256)
        self.enc4 = self.double_conv(256, 512)
        self.enc5 = self.double_conv(512, 1024)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: each dec* receives upsampled channels + skip channels,
        # hence its input width is twice the upconv output width.
        self.up4 = self.upconv(1024, 512)
        self.dec4 = self.double_conv(1024, 512)

        self.up3 = self.upconv(512, 256)
        self.dec3 = self.double_conv(512, 256)

        self.up2 = self.upconv(256, 128)
        self.dec2 = self.double_conv(256, 128)

        self.up1 = self.upconv(128, 64)
        self.dec1 = self.double_conv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        """2x2 transposed convolution with stride 2: doubles H and W."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _pad_and_concat(upsampled, skip):
        """Concatenate ``upsampled`` and ``skip`` along dim 1 after size matching.

        MaxPool2d floors odd sizes, so after a 2x upsample the decoder map can
        be one pixel smaller than the encoder map it is joined with. Zero-pad
        ``upsampled`` (split as evenly as possible between both sides) up to
        ``skip``'s spatial size so ``torch.cat`` does not fail.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            # pad order is last dim first: (left, right, top, bottom).
            # nn.functional is used because this cell does not import F.
            upsampled = nn.functional.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                                      diff_h // 2, diff_h - diff_h // 2])
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Run the network.

        ``x`` is indexed as (batch, channels, H, W). H and W need not be
        multiples of 16, but must each be at least 16 so the four poolings
        leave a non-empty map.

        Returns:
            Tensor of shape (batch, out_channels, H, W) with raw (un-activated)
            scores from ``final_conv``.
        """
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        x = self.dec4(self._pad_and_concat(self.up4(enc5), enc4))
        x = self.dec3(self._pad_and_concat(self.up3(x), enc3))
        x = self.dec2(self._pad_and_concat(self.up2(x), enc2))
        x = self.dec1(self._pad_and_concat(self.up1(x), enc1))

        return self.final_conv(x)

# Custom Dataset for loading face images and corresponding masks
class FaceDataset(Dataset):
    """(image, mask) pairs listed in a CSV file.

    Only rows whose ``with_mask`` column equals 1 are kept. For each row, the
    file named in ``filename`` is loaded from ``face_crop_dir`` (converted to
    RGB) and from ``face_segmentation_dir`` (converted to single-channel 'L').
    Both are cropped to the box (X1, Y1, X2, Y2), then transformed.

    Args:
        csv_file: path to a CSV with at least the columns
            filename, with_mask, X1, Y1, X2, Y2.
        face_crop_dir: directory holding the input images.
        face_segmentation_dir: directory holding masks with the same file names.
        transform: optional callable applied to the image. It is also applied
            to the mask unless ``mask_transform`` is given.
        mask_transform: optional separate callable for the mask. Useful because
            an interpolating resize (e.g. bilinear) blurs mask boundaries into
            intermediate values; a nearest-neighbour resize keeps them hard.
    """

    def __init__(self, csv_file, face_crop_dir, face_segmentation_dir, transform=None,
                 mask_transform=None):
        data = pd.read_csv(csv_file)
        # Keep only rows that have a mask. Reset the index so row labels match
        # the positions __getitem__ reads with iloc.
        # NOTE(review): assumes with_mask parses as 0/1 or bool; string values
        # such as "True" would not compare equal to 1 -- confirm with the CSV.
        self.data = data[data['with_mask'] == 1].reset_index(drop=True)
        self.face_crop_dir = face_crop_dir
        self.face_segmentation_dir = face_segmentation_dir
        self.transform = transform
        # Fall back to the image transform so existing callers behave as before.
        self.mask_transform = mask_transform if mask_transform is not None else transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_name = row['filename']
        box = (row['X1'], row['Y1'], row['X2'], row['Y2'])

        # NOTE(review): the images come from a "face_crop" folder yet are
        # cropped again by (X1, Y1, X2, Y2); confirm the box is expressed in
        # the coordinates of these files.
        # `with` closes the file handles instead of leaving them to the GC.
        with Image.open(os.path.join(self.face_crop_dir, img_name)) as raw_img:
            img = raw_img.convert('RGB').crop(box)
        with Image.open(os.path.join(self.face_segmentation_dir, img_name)) as raw_mask:
            mask = raw_mask.convert('L').crop(box)

        if self.transform:
            img = self.transform(img)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return img, mask

# Function to prepare transforms for image and mask
def get_transforms(size=(256, 256), interpolation=None):
    """Build a resize -> ToTensor pipeline.

    Args:
        size: target (height, width) passed to ``transforms.Resize``.
        interpolation: optional ``transforms.InterpolationMode``. ``None`` keeps
            ``Resize``'s default (bilinear). Pass ``InterpolationMode.NEAREST``
            when building the transform for label masks, so resizing does not
            introduce intermediate label values.

    Returns:
        A ``transforms.Compose`` that resizes a PIL image and converts it to a
        tensor with ``ToTensor``.
    """
    if interpolation is None:
        resize = transforms.Resize(size)
    else:
        resize = transforms.Resize(size, interpolation=interpolation)
    return transforms.Compose([resize, transforms.ToTensor()])

# File paths
csv_file = 'dataset.csv'  # Path to your dataset CSV file
zip_path = 'MSFD.zip'  # Path to the MSFD ZIP file
extract_to = 'MSFD_extracted'  # Path where the ZIP will be extracted
batch_size = 10  # Images per training batch

# Folders inside the extracted archive that the dataset reads from
face_crop_dir = os.path.join(extract_to, 'MSFD', '1', 'face_crop')
face_segmentation_dir = os.path.join(extract_to, 'MSFD', '1', 'face_crop_segmentation')

# Extract MSFD.zip only when the folders we need are missing, so re-running
# this cell does not rewrite the whole archive each time.
# Delete `extract_to` to force a fresh extraction.
if not (os.path.isdir(face_crop_dir) and os.path.isdir(face_segmentation_dir)):
    extract_zip(zip_path, extract_to)

# Initialize dataset and dataloader
transform = get_transforms()
dataset = FaceDataset(csv_file, face_crop_dir, face_segmentation_dir, transform=transform)
assert len(dataset) > 0, f"no rows with with_mask == 1 in {csv_file}"
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training Loop (for UNet)
def train_unet(model, dataloader, device, optimizer=None):
    """Run one training epoch of ``model`` over ``dataloader``.

    Args:
        model: network whose output for an image batch has the same shape as
            the corresponding mask batch (required by BCEWithLogitsLoss).
        dataloader: iterable of ``(images, masks)`` tensor pairs.
        device: device each batch is moved to; ``model`` is assumed to be on
            it already.
        optimizer: optimizer over ``model``'s parameters. If ``None``, a fresh
            ``torch.optim.Adam(model.parameters(), lr=1e-3)`` is created for
            this call. Pass your own to keep optimizer state across epochs.

    Returns:
        Mean batch loss as a float, or ``nan`` if ``dataloader`` yields nothing.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()  # built once instead of per batch
    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch_idx, (images, masks) in enumerate(dataloader):
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients from the previous step; otherwise backward()
        # accumulates into .grad across batches.
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        print(f"Batch {batch_idx + 1}, Loss: {batch_loss}")

    return total_loss / num_batches if num_batches else float("nan")

# Initialize model
# Seed torch's global RNG so weight initialisation and the DataLoader's shuffle
# order are reproducible across kernel restarts. GPU kernels may still be
# nondeterministic.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=3, out_channels=1).to(device)

# Start training
train_unet(model, dataloader, device)


ModuleNotFoundError: No module named 'torchvision'