In [1]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import zipfile

In [2]:
zip_file_path = 'Spectrogram_Dataset.zip'
extract_dir = 'Spectrogram_Dataset'

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Paths and dimensions
input_dir = os.path.join('Spectrogram_Dataset', 'Spectrogram_Dataset', 'Input')
output_dir = os.path.join('Spectrogram_Dataset', 'Spectrogram_Dataset','Output')
height, width = 500, 500

In [5]:
# Dataset class
class UNetDataset(Dataset):
    def __init__(self, input_dir, output_dir, transform=None, target_transform=None, image_size=(256, 256)):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.transform = transform
        self.target_transform = target_transform
        self.image_size = image_size

        # List all input files
        self.input_files = sorted(os.listdir(input_dir))

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        # Load input image
        input_file = self.input_files[idx]
        input_path = os.path.join(self.input_dir, input_file)
        input_image = Image.open(input_path).convert('L')  # Convert to grayscale

        # Load corresponding output folder
        track_name = input_file.split('_mix')[0]
        output_folder = os.path.join(self.output_dir, track_name)
        output_files = sorted(os.listdir(output_folder))

        # Load and stack output images
        output_images = []
        for output_file in output_files:
            output_path = os.path.join(output_folder, output_file)
            output_image = Image.open(output_path).convert('L')  # Convert to grayscale
            output_images.append(output_image)

        # Resize input and output images
        if self.image_size:
            input_image = input_image.resize(self.image_size)
            output_images = [img.resize(self.image_size) for img in output_images]

        # Apply transformations
        if self.transform:
            input_image = self.transform(input_image)
        else:
            input_image = transforms.ToTensor()(input_image)  # Default transform to tensor

        if self.target_transform:
            output_images = [self.target_transform(img) for img in output_images]
        else:
            output_images = [transforms.ToTensor()(img) for img in output_images]

        # Stack output images along the channel axis
        output_tensor = torch.cat(output_images, dim=0)

        return input_image, output_tensor

In [6]:
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor()
])

In [7]:
# Create dataset
dataset = UNetDataset(input_dir, output_dir, transform=transform)

# DataLoader for batching and shuffling
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [8]:
# Check
for inputs, outputs in dataloader:
    print(f"Input shape: {inputs.shape}, Output shape: {outputs.shape}")
    break

Input shape: torch.Size([3, 1, 256, 256]), Output shape: torch.Size([3, 5, 256, 256])


In [9]:
# Unet model

class UNET(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder part of unet
        self.encoder1 = self.conv_block(in_channels, 32)
        self.encoder2 = self.conv_block(32, 64)
        self.encoder3 = self.conv_block(64, 128)
        self.encoder4 = self.conv_block(128, 256)
        self.encoder5 = self.conv_block(256, 512)

        # bottleneck layer
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder part of unet
        self.upsampling5 = self.upsampling_block(1024, 512)
        self.decoder5 = self.conv_block(1024, 512)
        self.upsampling4 = self.upsampling_block(512, 256)
        self.decoder4 = self.conv_block(512, 256)
        self.upsampling3 = self.upsampling_block(256, 128)
        self.decoder3 = self.conv_block(256, 128)
        self.upsampling2 = self.upsampling_block(128, 64)
        self.decoder2 = self.conv_block(128, 64)
        self.upsampling1 = self.upsampling_block(64, 32)
        self.decoder1 = self.conv_block(64, 32)


        # changing to desired number of channels
        self.output = nn.Conv2d(32, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        conv =  nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU()
        )

        return conv

    def forward(self, input):

        # Encoder part of unet
        encoder1 = self.encoder1(input)
        encoder2 = self.encoder2(nn.MaxPool2d(2)(encoder1))
        encoder3 = self.encoder3(nn.MaxPool2d(2)(encoder2))
        encoder4 = self.encoder4(nn.MaxPool2d(2)(encoder3))
        encoder5 = self.encoder5(nn.MaxPool2d(2)(encoder4))

        # bottleneck layer
        bottleneck = self.bottleneck(nn.MaxPool2d(2)(encoder5))

        # decoder part of unet
        decoder5 = self.upsampling5(bottleneck)
        decoder5 = torch.cat((decoder5, encoder5), dim=1)
        decoder5 = self.decoder5(decoder5)

        decoder4 = self.upsampling4(decoder5)
        decoder4 = torch.cat((decoder4, encoder4), dim=1)
        decoder4 = self.decoder4(decoder4)

        decoder3 = self.upsampling3(decoder4)
        decoder3 = torch.cat((decoder3, encoder3), dim=1)
        decoder3 = self.decoder3(decoder3)

        decoder2 = self.upsampling2(decoder3)
        decoder2 = torch.cat((decoder2, encoder2), dim=1)
        decoder2 = self.decoder2(decoder2)

        decoder1 = self.upsampling1(decoder2)
        decoder1 = torch.cat((decoder1, encoder1), dim=1)
        decoder1 = self.decoder1(decoder1)

        output = self.output(decoder1)
        return output

    def upsampling_block(self, in_channels, out_channels):
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

In [10]:
in_channels, out_channels = 1, 5
model = UNET(in_channels, out_channels).to(device)
print(model)

UNET(
  (encoder1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): ReLU()
  )
  (encoder2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): ReLU()
  )
  (encoder3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): ReLU()
  )
  (encoder4): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): ReLU()
  )
  (encoder5): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(

In [11]:
class EnergyBasedLossFunction(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets, epsilon=1e-6):
        """
        predictions: Tensor of shape (B, N, T), predicted signals
        targets: Tensor of shape (B, N, T), ground truth signals
        epsilon: Small constant to avoid division by zero
        """
        # Compute MSE loss for each source in each sample
        mse_loss = torch.mean((predictions - targets) ** 2, dim=-1) # Shape: (B, N)

        # Compute energy for each source in each sample
        energies = torch.sum(targets ** 2, dim=-1) # Shape: (B, N)

        # Compute weights for each source in each sample
        weights = 1.0 / (energies + epsilon) # Shape: (B, N)

        # Compute weighted loss for each source in each sample
        weighted_losses = weights * mse_loss # Shape: (B, N)

        # Average over all sources and batch samples
        total_loss = torch.mean(weighted_losses) # Scalar

        return total_loss

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
loss_fn = EnergyBasedLossFunction()

In [13]:
# training
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [14]:
# testing
def test(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    model.eval()
    test_loss, correct_preds = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct_preds += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct_preds /= size
    print(
        f"Test Error: \n Accuracy: {correct_preds*100:>7f}%, Avg loss: {test_loss:>8f}"
    )

In [15]:
epochs = 15
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}\n-------------------------")
    train(dataloader, model, loss_fn, optimizer)

Epoch 1
-------------------------
loss: 0.004181  [    0/    3]
Epoch 2
-------------------------
loss: 0.004130  [    0/    3]
Epoch 3
-------------------------
loss: 0.004087  [    0/    3]
Epoch 4
-------------------------
loss: 0.004050  [    0/    3]
Epoch 5
-------------------------
loss: 0.004021  [    0/    3]
Epoch 6
-------------------------
loss: 0.003996  [    0/    3]
Epoch 7
-------------------------
loss: 0.003973  [    0/    3]
Epoch 8
-------------------------
loss: 0.003952  [    0/    3]
Epoch 9
-------------------------
loss: 0.003933  [    0/    3]
Epoch 10
-------------------------
loss: 0.003915  [    0/    3]
Epoch 11
-------------------------
loss: 0.003899  [    0/    3]
Epoch 12
-------------------------
loss: 0.003883  [    0/    3]
Epoch 13
-------------------------
loss: 0.003868  [    0/    3]
Epoch 14
-------------------------
loss: 0.003853  [    0/    3]
Epoch 15
-------------------------
loss: 0.003838  [    0/    3]


In [16]:
torch.save(model.state_dict(), 'model_weights.pth')

In [None]:
# Load the image
image_path = 'path_to_your_image.jpg'  # Path to the image you want to classify
image = Image.open(image_path)

# Define the image transformation (e.g., resize, to tensor, normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 for models like ResNet
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Pre-trained ImageNet normalization
])

# Apply transformations to the image
input_tensor = transform(image)

# Add batch dimension (as models expect a batch, even for one image)
input_tensor = input_tensor.unsqueeze(0)
