In [None]:
import torch
import torch.nn as nn
import librosa
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import librosa.display
import os
import io




In [None]:

# Number of genre classes; used as the default `num_classes` of GenreResNet,
# i.e. the width of its final linear layer. Keep it in sync with the label
# list used at prediction time.
NUM_CLASSES = 4  

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    Main path: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm. The
    shortcut is added to the main path and the sum goes through a final ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional module applied to the input to build the shortcut;
            when ``None`` the unmodified input is used as the shortcut.
    """

    # Channel multiplier that the network builder reads as ``block.expansion``.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions (BatchNorm follows, so no bias).
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Use the projection when one was supplied, otherwise the raw input.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class GenreResNet(nn.Module):
    """ResNet-style CNN that maps 3-channel images to class logits.

    Layout: 7x7 stem convolution -> BatchNorm -> ReLU -> max-pool -> four
    residual stages (64, 128, 256, 512 base channels) -> global average pool
    -> linear classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` attribute
            and accept ``(in_channels, out_channels, stride, downsample)``.
        layers: Indexable of at least four ints, the block count per stage.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=NUM_CLASSES):
        super().__init__()
        # Running channel count; _make_layer reads and updates it per stage.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first is built with stride 2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classification head.
        feature_dim = 512 * block.expansion
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(feature_dim, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.in_channels``."""
        stage_width = out_channels * block.expansion

        # Only the first block of a stage can change stride or width, so it is
        # the only one that may need a 1x1 projection on its shortcut.
        projection = None
        if stride != 1 or self.in_channels != stage_width:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, stage_width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_width),
            )

        stage = [block(self.in_channels, out_channels, stride, projection)]
        self.in_channels = stage_width
        stage += [block(self.in_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four residual stages and the classifier head; return logits."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)

def resnet18_genre_classifier():
    """Build the ResNet-18 genre classifier: ``GenreResNet`` with two ``BasicBlock``s per stage."""
    blocks_per_stage = [2, 2, 2, 2]
    return GenreResNet(BasicBlock, blocks_per_stage)


In [None]:



def preprocess_audio(audio_path, duration=30):
    """
    Turn an audio file into the image tensor fed to the genre model.

    1. Loads the file as mono at its native sample rate, reading at most
       ``duration`` seconds.
    2. Zero-pads clips shorter than ``duration`` so every input covers the
       same time span.
    3. Computes a 128-band Mel spectrogram (dB scale, fmax=8000) and renders
       it with Matplotlib on a 2.56in x 2.56in figure at 100 dpi (nominally
       256x256 pixels) with the axes hidden.
    4. Converts the rendered RGB pixels to a float tensor with ``ToTensor``.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        duration: Clip length in seconds (int or float).

    Returns:
        A tensor of shape ``(1, 3, H, W)``, or ``None`` if any step fails.
        Failures are reported with ``print`` rather than raised.
    """
    try:
        # sr=None keeps the file's native sample rate instead of resampling.
        y, sr = librosa.load(audio_path, sr=None, mono=True, duration=duration)

        # Pad widths must be integers, so round the target length; this keeps
        # fractional durations (e.g. 2.5 s) from failing inside np.pad.
        target_len = int(round(sr * duration))
        if len(y) < target_len:
            y = np.pad(y, (0, target_len - len(y)), 'constant')

        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
        S_DB = librosa.power_to_db(S, ref=np.max)

        fig = plt.figure(figsize=(2.56, 2.56), dpi=100)
        try:
            librosa.display.specshow(S_DB, sr=sr, cmap='magma')
            plt.axis('off')
            plt.tight_layout(pad=0)

            fig.canvas.draw()
            # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and
            # removed in later releases. buffer_rgba() exposes the same
            # rendered pixels as an (H, W, 4) array, so no manual reshape is
            # needed. Drop the alpha channel and copy so the pixels outlive
            # the figure.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            img = rgba[..., :3].copy()
        finally:
            # Always release the figure, even if rendering raised.
            plt.close(fig)

        pil_img = Image.fromarray(img)

        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        image_tensor = transform(pil_img)

        # Add the batch dimension expected by the model.
        return image_tensor.unsqueeze(0)

    except Exception as e:
        print(f"Error processing audio file: {e}")
        return None

def predict_genre(model, audio_tensor, class_names):
    """
    Classify one preprocessed clip and return its genre label.

    Switches ``model`` to evaluation mode, runs a gradient-free forward pass
    and looks up the index of the largest score along dim 1 in
    ``class_names``. ``audio_tensor`` must hold a single sample, because the
    winning index is read with ``.item()``.
    """
    model.eval()
    with torch.no_grad():
        scores = model(audio_tensor)
        winner = torch.max(scores, dim=1).indices.item()
        return class_names[winner]




In [None]:

# Script entry point: input locations for the prediction run.
if __name__ == '__main__':

    # Checkpoint file holding the trained weights; it is passed to torch.load
    # and then to model.load_state_dict further down.
    # NOTE(review): absolute, machine-specific path -- edit before running elsewhere.
    MODEL_PATH = '/Users/abynaya/Documents/best_genre_resnet18.pth' 
    # Audio clip to classify (same caveat: absolute local path).
    AUDIO_FILE = '/Users/abynaya/Downloads/Green Day - Last Night on Earth.mp3'

In [None]:

    # This cell is indented because it continues the
    # `if __name__ == '__main__':` block opened in the previous cell.

    # Label for each model output index; predict_genre indexes this list with
    # the highest-scoring class.
    # NOTE(review): the order presumably has to match the label order used
    # during training -- confirm against the training code.
    GENRE_CLASSES = ['classical', 'hiphop', 'pop', 'rock'] 

    # --- 1. Load Model ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Instantiate the model; its weights are loaded below, after both input
    # paths have been checked.
    model = resnet18_genre_classifier()
    
    # Stop with a message if either input file is missing.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at '{MODEL_PATH}'")
    elif not os.path.exists(AUDIO_FILE):
        print(f"Error: Audio file not found at '{AUDIO_FILE}'")
    else:
        try:
            # Load the saved state dictionary.
            # map_location remaps the stored tensors onto `device`, so a
            # checkpoint trained on GPU can be loaded on a CPU-only machine.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
            model.to(device)
            print("Model loaded successfully.")

            # --- 2. Preprocess Audio ---
            print("Processing audio file...")
            audio_tensor = preprocess_audio(AUDIO_FILE)

            # preprocess_audio returns None after printing its own error, so
            # there is nothing further to report in that case.
            if audio_tensor is not None:
                # The input must live on the same device as the model.
                audio_tensor = audio_tensor.to(device)

                # --- 3. Predict Genre ---
                predicted_genre = predict_genre(model, audio_tensor, GENRE_CLASSES)
                print("\n" + "="*30)
                print(f"The predicted genre for '{os.path.basename(AUDIO_FILE)}' is: {predicted_genre.upper()}")
                print("="*30)

        except Exception as e:
            # Catches any failure in weight loading or inference; the
            # architecture hint printed below is only one possible cause.
            print(f"An error occurred during model loading or prediction: {e}")
            print("Please ensure the model architecture in this script matches the one used for training.")