In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset,DataLoader,random_split
import matplotlib.pyplot as plt
import datetime
import copy
import torchmetrics

In [None]:
# Print the installed torchmetrics version so the notebook output records which release was used.
print(torchmetrics.__version__)

In [None]:
import sys,os
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
  from google.colab import files
  file=files.upload() # upload the savced kaggle.json
  !mkdir /root/.kaggle # on colab you are use root
  !mv kaggle.json  /root/.kaggle
  !kaggle datasets download -d andradaolteanu/gtzan-dataset-music-genre-classification
  !pip install lightning pygments==2.6.1 >/dev/null 2>&1
  !pip install comet-ml >/dev/null 2>&1
  !unzip -q /content/gtzan-dataset-music-genre-classification.zip
  data_path="./Data/images_original"
elif os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None:
  !pip install lightning > /dev/null 2>&1
  !pip install comet-ml >/dev/null 2>&1
  data_path="/kaggle/input/gtzan-dataset-music-genre-classification/Data/images_original"
else:
  data_path="./archive/Data/images_original"


In [None]:
import lightning as L
import comet_ml
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger,CometLogger
from lightning.pytorch.callbacks import ModelCheckpoint,Callback

In [None]:
# DataLoader settings used for both the training and the validation loader.
batch_size=32
num_workers=0

# Point Comet at the project that experiments from this notebook are logged to.
comet_ml.init(project_name="music-classification")

# Fix the random seed for the whole run (workers=True is forwarded to Lightning).
seed_everything(123, workers=True)


In [None]:
class CustomDataset(Dataset):
    """Dataset view that applies a transform to the sample of each ``(sample, label)`` pair.

    Args:
        subset: Indexable, sized collection whose items unpack into two values.
        transform: Optional callable applied to the first value of each item;
            when falsy (e.g. ``None``) items are returned unchanged.
    """
    def __init__(self,subset,transform=None):
        self.transform=transform
        self.subset=subset

    def __len__(self):
        """Return the number of items in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self,idx):
        """Return ``(sample, label)`` for ``idx``, with the sample transformed if a transform is set."""
        sample,label=self.subset[idx]
        if not self.transform:
            return sample,label
        return self.transform(sample),label

In [None]:
def _crop_tensor_normalize():
    """Build the steps shared by both pipelines: 224-pixel centre crop, tensor conversion, normalisation."""
    return [
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # mean 0 / std 1 leaves the tensor values unchanged
        transforms.Normalize([0., 0., 0.], [1., 1., 1.]),
    ]

# Per-split preprocessing: only the training split gets TrivialAugmentWide augmentation.
data_transforms = {
    'train': transforms.Compose([transforms.TrivialAugmentWide(), *_crop_tensor_normalize()]),
    'val': transforms.Compose(_crop_tensor_normalize()),
}

In [None]:
# Load the images; ImageFolder exposes the class names it found as `dataset.classes`.
dataset=torchvision.datasets.ImageFolder(data_path)
num_classes=len(dataset.classes)

# NOTE: this name shadows the `datasets` module imported from torchvision at the top
# of the notebook; ImageFolder is therefore reached through `torchvision.datasets`.
datasets={}

# 80/20 train/validation split. A dedicated seeded generator makes the split
# identical every time this cell runs, instead of depending on how much of the
# global RNG stream has already been consumed (which would silently move samples
# between train and validation when the cell is re-executed).
datasets['train'],datasets['val']=random_split(
    dataset,
    lengths=[0.8,0.2],
    generator=torch.Generator().manual_seed(123),
)

In [None]:
# Pair each split with its own transform pipeline.
image_datasets = {
    split: CustomDataset(datasets[split],data_transforms[split])
    for split in ('train','val')
}

# One DataLoader per split; only the training loader shuffles.
dataloaders = {
    split: DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=(split=='train'),
        num_workers=num_workers,
    )
    for split,split_dataset in image_datasets.items()
}

In [None]:
class LResNet(L.LightningModule):
    """LightningModule that trains a classifier with cross-entropy loss and SGD.

    Args:
        model: Callable (for example an ``nn.Module`` subclass) invoked as
            ``model(num_classes=num_classes)`` to build the network.
        num_classes: Number of target classes; also configures the accuracy metrics.
        lr: Learning rate handed to ``optim.SGD``.
        momentum: Momentum handed to ``optim.SGD``.
    """
    def __init__(self,model,num_classes,lr=1e-3,momentum=0):
        super().__init__()
        self.model=model(num_classes=num_classes)
        self.lr=lr
        self.momentum=momentum
        self.num_classes=num_classes
        self.save_hyperparameters()
        # Separate metric objects so training and validation statistics never mix.
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=self.num_classes)
        self.valid_acc = torchmetrics.Accuracy(task='multiclass', num_classes=self.num_classes)

    def _step(self,batch,accuracy,stage):
        """Run one forward pass, update ``accuracy`` and log ``<stage>_loss`` / ``<stage>_acc`` per epoch."""
        x,y=batch
        y_hat=self.model(x)
        loss=nn.functional.cross_entropy(y_hat,y)
        accuracy(y_hat,y)
        self.log(f"{stage}_loss",loss,prog_bar=True,on_epoch=True,on_step=False)
        # The metric object itself is logged rather than a per-batch value.
        self.log(f"{stage}_acc",accuracy,prog_bar=True,on_epoch=True,on_step=False)
        return loss

    def training_step(self,batch,batch_idx):
        """Return the training loss for ``batch``; logs ``train_loss`` and ``train_acc``."""
        return self._step(batch,self.train_acc,"train")

    def validation_step(self,batch,batch_idx):
        """Return the validation loss for ``batch``; logs ``val_loss`` and ``val_acc``."""
        # Uses the metric registered in __init__ as `valid_acc`; the previous
        # reference to `self.val_acc` pointed at an attribute that was never created.
        return self._step(batch,self.valid_acc,"val")

    def configure_optimizers(self):
        """Return an SGD optimizer over the wrapped model's parameters."""
        return optim.SGD(self.model.parameters(),
                        lr=self.lr,momentum=self.momentum)

In [None]:
class ResNet(nn.Module):
    """ResNet-18 loaded with torchvision's default weights and a new final layer.

    Args:
        num_classes: Output size of the replacement fully connected layer.
    """
    def __init__(self,num_classes):
        super().__init__()
        backbone=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Swap the classifier head for one with `num_classes` outputs.
        head_inputs=backbone.fc.in_features
        backbone.fc=nn.Linear(head_inputs,num_classes)
        self.resnet=backbone

    def forward(self,x):
        """Return the backbone's output for input ``x``."""
        return self.resnet(x)

In [13]:
# Training hyper-parameters.
max_epochs=50
learning_rate=0.001
momentum=0.9
# Size the classifier from the class count derived from the dataset folders
# (`num_classes`, computed when the ImageFolder was loaded) instead of a
# hard-coded 10, so the head always matches the data actually on disk.
Lmodule=LResNet(ResNet,num_classes=num_classes,lr=learning_rate,momentum=momentum)


AssertionError: 

In [None]:
# Log to Comet and to local CSV files under logs/music.
comet_logger=CometLogger(experiment_name='music_classification')
csv_logger=CSVLogger(save_dir="logs",name="music")

# Keep the three checkpoints with the highest logged "val_acc".
checkpoint_callback=ModelCheckpoint(
    dirpath="./checkpoints",
    monitor="val_acc",
    mode="max",
    save_top_k=3,
)

trainer=L.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=max_epochs,
    deterministic=True,
    logger=[comet_logger,csv_logger],
    callbacks=checkpoint_callback,
)


In [None]:
# Train the module, validating on the held-out split.
trainer.fit(
    model=Lmodule,
    train_dataloaders=dataloaders['train'],
    val_dataloaders=dataloaders['val'],
)


In [None]:
# Explicitly end the Comet experiment held by the logger; this cell follows the training cell.
comet_logger.experiment.end()