In [5]:
import timm
from timm.models import registry
from SegFunctions import *
import echonet
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Seed torch's global RNG (decoder weight init, DataLoader shuffling) so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Fail fast if the installed timm build does not provide the Swin backbone used below.
m = 'swin_tiny_patch4_window7_224'
assert registry.is_model(m), f"timm has no model named {m!r}; upgrade timm"

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:


# Albumentations pipeline applied jointly to (image, mask) pairs:
#   1. resize both to the 224x224 input size of swin_tiny_patch4_window7_224,
#   2. normalize the image only (mean 0.5 / std 0.5 after scaling by 255),
#   3. convert HWC numpy arrays into torch tensors.
_transform_steps = [
    A.Resize(height=224, width=224),
    A.Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2(),
]
transform = A.Compose(_transform_steps)

class EchoNetSegmentationDataset(echonet.datasets.Echo):
    """EchoNet-Dynamic split that yields ``(image, mask)`` pairs for LV segmentation.

    Each item is one annotated frame of a video plus its traced mask, both
    passed through the module-level ``transform``.

    ``target_type`` picks the annotated frame: ``"LargeFrame"`` or
    ``"SmallFrame"`` (a one-element list/tuple of either is also accepted).
    The matching ``"LargeTrace"`` / ``"SmallTrace"`` target is requested from
    the parent as well, so every item comes with a mask. All other arguments
    are forwarded unchanged to ``echonet.datasets.Echo``.

    NOTE(review): relies on ``echonet.datasets.Echo.__getitem__`` returning
    ``(video, target)`` where ``target`` is a ``(frame, trace)`` tuple when two
    target types are requested, with ``frame`` channel-first ``(3, H, W)`` and
    ``trace`` an ``(H, W)`` 0/1 array. Confirm against the installed echonet.
    """

    # Frame target -> trace target annotating the same frame.
    _TRACE_FOR_FRAME = {"LargeFrame": "LargeTrace", "SmallFrame": "SmallTrace"}

    def __init__(self, root=None, split="train", target_type="LargeFrame", *args, **kwargs):
        if isinstance(target_type, (list, tuple)):
            if len(target_type) != 1:
                raise ValueError(f"expected a single frame target_type, got {target_type!r}")
            target_type = target_type[0]
        if target_type not in self._TRACE_FOR_FRAME:
            raise ValueError(
                f"target_type must be one of {sorted(self._TRACE_FOR_FRAME)}, got {target_type!r}"
            )
        # Ask the parent for the frame AND its trace; a frame target alone carries no mask.
        super().__init__(root, split, [target_type, self._TRACE_FOR_FRAME[target_type]], *args, **kwargs)

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair for video ``index``.

        Returns:
            image: tensor produced by ``transform`` from the annotated frame.
            mask: float tensor ``(1, H, W)``; the leading channel axis makes a
                batch stack to ``(B, 1, H, W)`` like a single-channel model output.
        """
        # The sampled clip is not needed; only the annotated frame and its trace are.
        _clip, (frame, trace) = super().__getitem__(index)

        # albumentations expects channel-last (H, W, C), C-contiguous numpy input.
        image_hwc = np.ascontiguousarray(np.asarray(frame, dtype=np.float32).transpose(1, 2, 0))
        mask_hw = np.asarray(trace, dtype=np.float32)

        transformed = transform(image=image_hwc, mask=mask_hw)
        image, mask = transformed["image"], transformed["mask"]
        # An all-background mask is still a valid target. Returning None instead
        # would make the default DataLoader collate_fn fail.
        return image, mask.float().unsqueeze(0)


# Dataset root, relative to the notebook's working directory.
file_path = 'dynamic/a4c-video-dir/'

# Loader settings shared by both splits.
BATCH_SIZE = 8
NUM_WORKERS = 4
# Page-locked host memory speeds up host->GPU copies; it is pointless without CUDA.
PIN_MEMORY = torch.cuda.is_available()

train_dataset = EchoNetSegmentationDataset(root=file_path, split="train", target_type="LargeFrame")
val_dataset = EchoNetSegmentationDataset(root=file_path, split="val", target_type="LargeFrame")

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)



In [7]:
import timm
import torch.nn as nn

class SwinUNet(nn.Module):
    """Swin-Tiny encoder with a transposed-convolution decoder for binary segmentation.

    Only the deepest encoder feature map is decoded (there are no skip connections).

    ``forward`` always returns a ``(B, 1, H, W)`` map at the input's spatial size:
    raw logits in training mode (the input ``nn.BCEWithLogitsLoss`` expects) and
    sigmoid probabilities in eval mode (suitable for a ``> 0.5`` threshold).
    """

    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, features_only=True)
        # Expects the 768-channel deepest encoder stage. The three stride-2
        # transposed convs upsample 8x; the rest of the way back to input
        # resolution is done by interpolation in ``forward``.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(96, 1, kernel_size=1)  # Single-channel output (segmentation mask)
        )

    def forward(self, x):
        """Segment a batch ``x`` of shape ``(B, 3, H, W)``.

        Returns logits (train mode) or probabilities (eval mode), shape ``(B, 1, H, W)``.
        """
        out_size = x.shape[-2:]
        deepest = self.encoder(x)[-1]

        # timm's Swin feature extractor (recent versions) emits channel-last
        # (B, H, W, C) maps, but the conv decoder needs (B, C, H, W). Detect the
        # layout from the encoder's reported channel count so NCHW also works.
        channels = self.encoder.feature_info.channels()[-1]
        if deepest.shape[1] != channels and deepest.shape[-1] == channels:
            deepest = deepest.permute(0, 3, 1, 2)

        logits = self.decoder(deepest)
        # The decoder output is smaller than the input; resize so it lines up with the masks.
        logits = nn.functional.interpolate(logits, size=out_size, mode="bilinear", align_corners=False)

        if self.training:
            return logits  # BCEWithLogitsLoss applies the sigmoid itself
        return torch.sigmoid(logits)

# Build the network and place it on the GPU when one is present, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SwinUNet()
model = model.to(device)


In [8]:
# BCEWithLogitsLoss fuses the sigmoid into the loss, so it must be fed raw
# logits rather than probabilities. AdamW uses decoupled weight decay.
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)


In [None]:
# Mixed-precision (AMP) training, then a Dice evaluation on the validation split.
# AMP is only used on CUDA. On CPU the autocast context and the GradScaler are
# created disabled, so they become pass-throughs and the cell still runs.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def _to_device_batch(images, masks):
    """Move an (images, masks) batch to `device`, adding a channel axis to masks if missing."""
    images = images.to(device, non_blocking=True)
    masks = masks.to(device, non_blocking=True).float()
    if masks.dim() == 3:  # (B, H, W) -> (B, 1, H, W), matching a single-channel output
        masks = masks.unsqueeze(1)
    return images, masks


num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, masks in train_loader:
        # Batches leave the DataLoader on the CPU; the model lives on `device`.
        images, masks = _to_device_batch(images, masks)
        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, masks)

        # Scale the loss so fp16 gradients do not underflow. step() unscales
        # the gradients first and skips the update if they contain inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / len(train_loader)}")

# Validation: Dice between thresholded predictions and ground truth, averaged over batches.
model.eval()
total_dice = 0.0
dice_eps = 1e-7  # keeps Dice defined (== 1.0) when prediction and mask are both empty

with torch.no_grad():
    for images, masks in val_loader:
        images, masks = _to_device_batch(images, masks)
        outputs = model(images)
        # `outputs` are probabilities (SwinUNet applies a sigmoid at inference),
        # so 0.5 is the decision boundary.
        preds = (outputs > 0.5).float()

        intersection = (preds * masks).sum()
        union = preds.sum() + masks.sum()
        dice = (2.0 * intersection + dice_eps) / (union + dice_eps)
        total_dice += dice.item()

print(f"Validation Dice Score: {total_dice / max(len(val_loader), 1):.4f}")
