In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Custom Datasets
class TowerDataset(Dataset):
    """Image dataset whose (dx, dy) regression targets are encoded in the file names.

    Every ``.png`` / ``.jpg`` file directly inside ``root_dir`` is indexed. The
    file name (without its extension) is split on ``_``; the third and fourth
    fields carry the x and y offsets, each written as a one-character sign
    marker followed by an integer. A leading ``P`` means positive, any other
    marker is treated as negative. For example a name such as
    ``tower_001_P10_N5.png`` yields ``dx = 10`` and ``dy = -5``.

    Args:
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.

    Raises:
        ValueError: If an image file name does not contain two parsable
            offset fields.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []  # list of (file name, dx, dy)

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for img_name in sorted(os.listdir(root_dir)):
            if not img_name.endswith(('.png', '.jpg')):
                continue

            # Strip the extension before splitting so an offset stored in the
            # last field ("..._N5.png") is parsed as "N5" and not "N5.png".
            stem = os.path.splitext(img_name)[0]
            parts = stem.split('_')
            if len(parts) < 4:
                raise ValueError(
                    f"Cannot read offsets from file name {img_name!r}: expected at least "
                    f"4 underscore-separated fields, found {len(parts)}"
                )

            dx = self._parse_offset(parts[2], img_name)
            dy = self._parse_offset(parts[3], img_name)
            self.data.append((img_name, dx, dy))

    @staticmethod
    def _parse_offset(token, img_name):
        """Convert a sign-prefixed field such as ``P10`` / ``N5`` into a signed int."""
        try:
            magnitude = int(token[1:])
        except ValueError:
            raise ValueError(
                f"Cannot read offsets from file name {img_name!r}: field {token!r} is not "
                f"a sign marker followed by an integer"
            ) from None
        # 'P' marks a positive offset; every other marker negates the value.
        return magnitude if token.startswith('P') else -magnitude

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` if one was given) and ``label`` is a float32 tensor
            ``[dx, dy]``.
        """
        img_name, dx, dy = self.data[idx]
        img_path = os.path.join(self.root_dir, img_name)

        # The context manager closes the file handle as soon as the pixels are
        # decoded; convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = torch.tensor([dx, dy], dtype=torch.float32)
        return image, label

    
# Defining the Model
class TowerNet(nn.Module):
    """Small CNN that regresses two values (Δx, Δy) from a 3-channel image.

    The network is three ``Conv2d -> ReLU -> MaxPool2d(2, 2)`` stages followed
    by a two-layer fully connected head. Sub-module names and their order are
    kept stable so that ``state_dict`` keys do not change.

    Args:
        input_size: Spatial size of the input images, either a single int for
            square inputs or a ``(height, width)`` pair. Defaults to 150, which
            gives the 64 * 18 * 18 flattened feature size.

    Raises:
        ValueError: If the input is too small to survive the three pooling stages.
    """

    # Number of MaxPool2d(2, 2) stages in `features`; each one halves (floors) H and W.
    _NUM_POOLS = 3

    def __init__(self, input_size=150):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),  # keeps H x W
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 150 -> 75 for the default input

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # keeps H x W
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 75 -> 37

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # keeps H x W
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 37 -> 18
        )

        # The padded 3x3 convolutions preserve the spatial size, so only the
        # pooling stages shrink it: floor-divide by two once per pool.
        for _ in range(self._NUM_POOLS):
            height //= 2
            width //= 2
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size={input_size!r} is too small for {self._NUM_POOLS} 2x2 pooling stages"
            )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * height * width, 512),
            nn.ReLU(),
            nn.Linear(512, 2),  # Output Δx and Δy
        )

    def forward(self, x):
        """Run the convolutional features and the regression head; returns a (batch, 2) tensor."""
        x = self.features(x)
        x = self.fc(x)
        return x
    
    
# Train the model
def train_model(model, dataloader, criterion, optimizer, epochs=20, device=None):
    """Train ``model`` in place and report the mean batch loss of every epoch.

    Args:
        model: The ``nn.Module`` to optimise.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches. It is
            iterated once per epoch and does not need to implement ``__len__``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to. When ``None`` the device of
            the model's first parameter is used (CPU if the model has none), so
            the function does not depend on a module-level ``device`` variable.

    Returns:
        A list with one mean loss value per epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    history = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Counting batches (instead of len(dataloader)) avoids a ZeroDivisionError
        # on empty loaders and also works for loaders without a length.
        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot train on an empty dataset")

        mean_loss = epoch_loss / num_batches
        history.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

    return history

        
# Validate the model
def validate_model(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` without gradients and collect the results.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: The ``nn.Module`` to evaluate.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        device: Device the image batches are moved to. When ``None`` the device
            of the model's first parameter is used (CPU if the model has none),
            so the function does not depend on a module-level ``device`` variable.

    Returns:
        A tuple ``(predictions, ground_truths)`` of NumPy arrays, each the
        row-wise stack of all batches in iteration order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    predictions, ground_truths = [], []
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(device))
            predictions.append(outputs.cpu().numpy())
            # Labels are only collected, never fed to the model, so they are
            # not moved to the compute device first.
            ground_truths.append(labels.cpu().numpy())

    if not predictions:
        raise ValueError("dataloader yielded no batches; nothing to validate")
    return np.vstack(predictions), np.vstack(ground_truths)



In [None]:
# Configuration: the values a reader is most likely to tune
SEED = 42
IMAGE_SIZE = (150, 150)  # TowerNet's default head (64 * 18 * 18 inputs) is sized for 150x150 images
TRAIN_DIR = './cropped/Train'
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 20
CHECKPOINT_PATH = 'tower_net.pth'

# Seed PyTorch so weight initialisation and DataLoader shuffling repeat across runs.
# NOTE: some GPU kernels may still be non-deterministic.
torch.manual_seed(SEED)

# Data preprocessing and loading
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

train_dataset = TowerDataset(TRAIN_DIR, transform=transform)
# Fail early with a clear message instead of crashing inside the training loop.
assert len(train_dataset) > 0, f"no .png/.jpg training images found in {TRAIN_DIR}"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TowerNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Start training
train_model(model, train_loader, criterion, optimizer, epochs=NUM_EPOCHS)

# Save the trained weights (state_dict only, not the full module)
torch.save(model.state_dict(), CHECKPOINT_PATH)
