# CNN pattern model, proposal no.2.1

In [1]:
# Imports
import init
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import os
from src.loaders.CNNPatt2Dataset import CNNPatt2Dataset
from src.core.config import Conf
from src.models.train import train
from src.models.validate import RegionNetValidation
from src.loaders.spectrogram_cacher import SpectrogramCacher

validate_region = RegionNetValidation()

DATA_ROOT = "../kaggle-processed"
assert os.path.exists(DATA_ROOT)

## Config and w&b setup

In [2]:
config = Conf(
    loss_fn="BCE",
    n_epochs=100,
    n_reps_per_epoch=5,
    augmentation_std=0.3,
    empty_per_sound_multiplier=3,
    model_type="CNN with improved data augmentation",
)

wandb.login()
wandb_run = wandb.init(
    project="InzCNNRegionClassifier",
    notes="Region classification",
    config=config.to_dict()
)

[34m[1mwandb[0m: Currently logged in as: [33mimatynia[0m ([33mimatynia-inzynierka[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016673875849998392, max=1.0…

# Loading dataset

In [3]:
# setting up datasets
train_data = CNNPatt2Dataset(os.path.join(DATA_ROOT, "train"), config,
                             seed=10, region_classification=True, augment_gauss_max_std=config.augmentation_std)
valid_data = CNNPatt2Dataset(os.path.join(DATA_ROOT, "valid"), config,
                             seed=11, region_classification=True, augmentation=False)

print(f"Prepared {len(train_data)} training samples, {len(valid_data)} samples")
# set up data loader
train_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=3)
valid_loader = DataLoader(valid_data, batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=3)

2024-06-14 02:16:26.488 | INFO     | src.loaders.CNNPatt2Loader:__init__:58 - Loading annotations
2024-06-14 02:16:28.251 | INFO     | src.loaders.CNNPatt2Loader:_cache_items:169 - Caching dataset item descriptions
2024-06-14 02:16:28.711 | INFO     | src.loaders.CNNPatt2Loader:_cache_items:234 - Caching done: 5258 regions created
2024-06-14 02:16:28.712 | INFO     | src.loaders.CNNPatt2Loader:__init__:58 - Loading annotations
2024-06-14 02:16:29.129 | INFO     | src.loaders.CNNPatt2Loader:_cache_items:169 - Caching dataset item descriptions
2024-06-14 02:16:29.369 | INFO     | src.loaders.CNNPatt2Loader:_cache_items:234 - Caching done: 1150 regions created


Prepared 105160 training samples, 23000 samples


# Caching spectrograms

In [4]:
files_to_cache = train_data.get_all_files() + valid_data.get_all_files()
sample_cache = SpectrogramCacher.get_instance()
sample_cache.cache_all(files_to_cache, config.to_dict())

2024-06-14 02:16:29.377 | INFO     | src.loaders.spectrogram_cacher:cache_all:74 - Caching all 1606 files
2024-06-14 02:16:41.943 | INFO     | src.loaders.spectrogram_cacher:cache_all:80 - Caching completed


# MODEL

In [5]:
# STAGE 1: Classifying regions
torch.random.manual_seed(42)


class CNNRegionClassifier(nn.Module):
    OUTPUT_SHAPE = 2

    def __init__(self, config: Conf):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, config.cnn_1_filters, kernel_size=5, padding=2), 
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(config.cnn_1_filters, config.cnn_2_filters, kernel_size=5, padding=2),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(config.cnn_2_filters, config.cnn_3_filters, kernel_size=5, padding=2),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.flatten = nn.Flatten()
        self.fc_layers = nn.Sequential(
            nn.Linear(config.cnn_3_filters * (config.H // 8) * (config.W // 8), config.fc_size),  # Adjusted to the new output size
            nn.LeakyReLU(),
            nn.Dropout(config.fc_dropout),
            nn.Linear(config.fc_size, config.fc_size),
            nn.LeakyReLU(),
            nn.Dropout(config.fc_dropout),
            nn.Linear(config.fc_size, self.OUTPUT_SHAPE),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.flatten(x)
        x = self.fc_layers(x)
        return x

model = CNNRegionClassifier(config)

# TRAIN

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

loss_fn = nn.BCELoss()

In [7]:
train(
    config,
    train_loader,
    valid_loader,
    model,
    device,
    loss_fn,
    optimizer,
    validate_region,
    wandb_run
)

Epoch 1/100: 100%|██████████| 410/410 [00:55<00:00,  7.36it/s]
2024-06-14 02:17:37.944 | INFO     | src.models.train:train:24 - Epoch 1/100, Loss: 0.3310
2024-06-14 02:17:42.547 | INFO     | src.models.validate:__call__:96 - VALIDATION:     0/100 | BCE LOSS: 0.23752194, MSE LOSS: 0.07161916
2024-06-14 02:17:42.662 | INFO     | src.models.validate:__call__:108 - Better model saved
Epoch 2/100: 100%|██████████| 410/410 [00:55<00:00,  7.37it/s]
2024-06-14 02:18:38.341 | INFO     | src.models.train:train:24 - Epoch 2/100, Loss: 0.2295
Epoch 3/100: 100%|██████████| 410/410 [00:55<00:00,  7.36it/s]
2024-06-14 02:19:34.088 | INFO     | src.models.train:train:24 - Epoch 3/100, Loss: 0.2143
Epoch 4/100: 100%|██████████| 410/410 [00:55<00:00,  7.34it/s]
2024-06-14 02:20:29.952 | INFO     | src.models.train:train:24 - Epoch 4/100, Loss: 0.2017
Epoch 5/100: 100%|██████████| 410/410 [00:55<00:00,  7.34it/s]
2024-06-14 02:21:25.843 | INFO     | src.models.train:train:24 - Epoch 5/100, Loss: 0.1947
E

KeyboardInterrupt: 

In [8]:
wandb_run.finish()

0,1
epoch,▁▁▁▂▂▂▃▃▃▃▄▄▄▄▅▅▅▆▆▆▆▇▇▇███
train_loss,█▅▅▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
val_loss,▃▁▃▅█
val_mse,█▂▁▂▂

0,1
epoch,21.0
train_loss,0.08234
val_loss,0.29956
val_mse,0.06678
