In [None]:
import json
from transform import data_transforms
from WasteDataset import WasteDataset
from WasteDataloader import WasteDataLoader
from utils.image import (
    rotate_ccw_45,
    rotate_ccw_90,
    rotate_ccw_135,
    rotate_180,
    rotate_cw_45,
    rotate_cw_90,
    rotate_cw_135
)

# Load the image-path listing (one entry per dataset key). The context manager
# closes the file handle, which a bare `json.load(open(...))` leaves open.
with open('./data/data_paths2.json', 'r', encoding='utf8') as paths_file:
    data_paths = json.load(paths_file)

# Rotation augmentations available to the datasets.
# NOTE(review): this list is not handed to WasteDataset below (an empty list is
# passed instead), so no augmentation is applied — confirm that is intended.
augment_funcs = [
    rotate_ccw_45,
    rotate_ccw_90,
    rotate_ccw_135,
    rotate_180,
    rotate_cw_45,
    rotate_cw_90,
    rotate_cw_135
]
batch_size = 64

# One dataset per key of the JSON file, each paired with the transform
# registered under the same key in `data_transforms`.
datasets = {
    key: WasteDataset(paths, data_transforms[key], augment_functions=[])
    for key, paths in data_paths.items()
}

waste_dataloader = WasteDataLoader(datasets=datasets,
                                   batch_size=batch_size)

In [None]:
# Display the first item of the 'test' dataset.
waste_dataloader.get_dataset('test')[0]

In [None]:
from models.ResNetBaseClassifier import ResNetBaseClassifier
from models.DenseNetBaseClassifier import DenseNetBaseClassifier

class PretrainedModel:
    """String identifiers of the backbones this notebook can build a classifier from."""
    resnet, densenet = 'Resnet', 'Densenet'

# Backbone used for this run; the same string is reused later as the MLflow
# experiment name.
pretrained_model = PretrainedModel.densenet

if pretrained_model == PretrainedModel.resnet:
    model = ResNetBaseClassifier(waste_dataloader.n_classes())
elif pretrained_model == PretrainedModel.densenet:
    # NOTE(review): the meaning of the second positional argument (0) is
    # defined by DenseNetBaseClassifier — confirm against that class.
    model = DenseNetBaseClassifier(waste_dataloader.n_classes(), 0)
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(f'Unsupported pretrained model: {pretrained_model!r}')

model.to('cuda')

In [None]:
# import torch

# model.load_state_dict(torch.load('./checkpoints/waste_weights_resnet_22112023_9.h5'))
# model.eval()

In [None]:
import torch

def compute_accuracy(y_pred, y_target):
    """Return the percentage of samples whose arg-max prediction equals the target.

    Args:
        y_pred: 2-D tensor of per-class scores; the predicted class of each
            row is the index of its maximum along dim 1.
        y_target: tensor of class indices, compared element-wise with the
            predicted indices.

    Returns:
        Accuracy as a float percentage (0-100). An empty batch yields 0.0
        instead of raising ZeroDivisionError.
    """
    n_samples = y_pred.size(0)
    if n_samples == 0:
        return 0.0

    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / n_samples * 100

In [None]:
import torch.optim as optim
import torch.nn as nn

# Optimisation setup.
# To warm-start from saved weights instead, move the model to the device and
# load the checkpoint before building the optimizer, e.g.:
#   model = model.to('cuda')
#   model.load_state_dict(torch.load('weights_40epoches.h5'))

# Only the parameters returned by model.get_fc_parameters() are optimised.
optimizer = optim.Adam(params=model.get_fc_parameters(), lr=1e-4)

# Scale the learning rate by 0.8 when the monitored quantity (to be minimised)
# stops improving; the scheduler only acts when its step() is called.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=1,
)
loss_func = nn.CrossEntropyLoss()

In [None]:
# Test 
from typing import Any, Set
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import pandas as pd
from sklearn.metrics import confusion_matrix


def classification_report_to_df(y_true: list,
                                y_pred: list,
                                labels: list) -> pd.DataFrame:
    """Build scikit-learn's classification report as a DataFrame.

    Args:
        y_true: ground-truth class labels.
        y_pred: predicted class labels, aligned with ``y_true``.
        labels: accepted for backward compatibility only. It is not used to
            build the report and, unlike before, is no longer modified
            (the previous implementation appended summary-row names to the
            caller's list without ever using them).

    Returns:
        The ``output_dict`` form of the report, transposed so that each
        report entry becomes a row and each metric a column.
    """
    report = classification_report(y_true, y_pred, output_dict=True)
    return pd.DataFrame(report).transpose()


def classification_report_to_csv(y_true: list,
                                 y_pred: list,
                                 csv_filepath: str='') -> None:
    """Write the classification report for ``y_true``/``y_pred`` to a CSV file.

    The report is computed as a dict, turned into a DataFrame, transposed and
    written to ``csv_filepath``.
    """
    report_dict = classification_report(y_true, y_pred, output_dict=True)
    pd.DataFrame(report_dict).T.to_csv(csv_filepath)
    
    
def conf_matrix(y_true: list, y_pred: list, labels: list) -> list:
    """Return scikit-learn's confusion matrix for the given true/predicted labels.

    ``labels`` is forwarded unchanged as the ``labels`` argument of
    ``confusion_matrix``.
    """
    return confusion_matrix(y_true, y_pred, labels=labels)

In [None]:
from tqdm import tqdm
import torch
import torch.nn as nn
import mlflow
from typing import Any, Set
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from mlflow_extend import mlflow as mlex
from WasteDataset import WasteDataset


def test(model: Any,
         dataloader: DataLoader,
         device: str='cuda') -> tuple:
    """Run ``model`` over ``dataloader`` and collect true and predicted labels.

    Inference runs without gradient tracking. When ``model`` is an
    ``nn.Module`` it is switched to eval mode for the duration of the call and
    put back into training mode afterwards if that is how it arrived.

    Args:
        model: callable mapping a batch of inputs to per-class scores.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(true_labels, pred_labels)``: two flat lists aligned sample by
        sample. Predictions are the arg-max over dim 1 of the model output.
    """
    pred_labels = []
    true_labels = []

    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    try:
        with torch.no_grad():
            for inputs, labels in tqdm(dataloader):
                outputs = model(inputs.to(device))
                # Softmax is monotonic, so the arg-max of the raw scores
                # selects the same class as the arg-max of the probabilities.
                pred_labels.extend(outputs.argmax(dim=1).cpu().tolist())
                # Store plain Python values rather than 0-d tensors.
                if hasattr(labels, 'tolist'):
                    true_labels.extend(labels.tolist())
                else:
                    true_labels.extend(labels)
    finally:
        if was_training:
            model.train()

    return true_labels, pred_labels


def log_metric(name, value, step):
    """Log the scalar ``value`` under ``name`` at ``step`` via ``mlflow.log_metric``."""
    mlflow.log_metric(key=name, value=value, step=step)
    
    
def log_param(key, value):
    """Log the parameter ``key`` with ``value`` via ``mlflow.log_param``."""
    mlflow.log_param(key=key, value=value)
    
    
def log_model(model: Any,
              model_name: str,
              dataloader: DataLoader,
              dataset: WasteDataset) -> None:
    """Log a model snapshot to MLflow together with its evaluation artifacts.

    Args:
        model: model to store with ``mlflow.pytorch.log_model``.
        model_name: artifact path of the model; also the folder prefix of the
            confusion-matrix image and the classification-report CSV.
        dataloader: batches the model is evaluated on through ``test``.
        dataset: provides ``get_non_empty_classes`` for the report and the
            confusion-matrix labels.

    The confusion matrix is best-effort: if building or logging it fails, a
    warning describing the error is emitted and the classification report is
    still logged.
    """
    import warnings

    mlflow.pytorch.log_model(model, model_name)

    y_true, y_pred = test(model, dataloader)
    df = classification_report_to_df(y_true, y_pred, list(dataset.get_non_empty_classes()))

    try:
        cm = conf_matrix(y_true, y_pred, list(dataset.get_non_empty_classes(True)))
        mlex.log_confusion_matrix(cm, f'{model_name}/confusion_matrix.png')
    except Exception as exc:
        # `except Exception` lets KeyboardInterrupt/SystemExit propagate and
        # makes the failure visible instead of silently discarding it.
        warnings.warn(f'Could not log confusion matrix for {model_name}: {exc!r}')

    mlex.log_df(df, f'{model_name}/classification_report.csv')


def train_model(model: nn.Module,
                data_loaders: WasteDataLoader,
                loss_function: nn.CrossEntropyLoss,
                optimizer: optim.Adam,
                scheduler: optim.lr_scheduler.ReduceLROnPlateau,
                num_epochs: int=1,
                device: str='cuda',
                losses: dict=None,
                accuracies: dict=None,
                step_model_logging: int=5) -> tuple:
    """Train ``model`` and evaluate it on the validation set after every epoch.

    Args:
        model: network to optimise; switched between train and eval mode per phase.
        data_loaders: provides ``get_dataloader(phase)``, ``get_dataset(phase)``
            and ``size_of_set(phase)`` for 'train', 'validation' and 'test'.
        loss_function: criterion applied to ``(outputs, labels)``.
        optimizer: optimiser stepped once per training batch.
        scheduler: stepped once per epoch with the validation loss; pass
            ``None`` to keep the learning rate fixed.
        num_epochs: number of passes over the training set.
        device: device the batches are moved to.
        losses: optional ``{phase: [epoch losses]}`` history to extend; a new
            dict is created when omitted and missing phase keys are added.
        accuracies: optional ``{phase: [epoch accuracies]}`` history, handled
            like ``losses``.
        step_model_logging: the model is logged (with an evaluation on the
            'test' loader) every this many epochs.

    Returns:
        ``(model, losses, accuracies)``.
    """
    phases = ['train', 'validation']

    # Fresh containers per call: a `{}` default would be shared across calls
    # and would not contain the phase keys appended to below.
    losses = {} if losses is None else losses
    accuracies = {} if accuracies is None else accuracies
    for phase in phases:
        losses.setdefault(phase, [])
        accuracies.setdefault(phase, [])

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in phases:
            print('Phase:', phase)
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(data_loaders.get_dataloader(phase)):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Track gradients only while training; validation needs no graph.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)

                if is_training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                _, preds = torch.max(outputs, 1)
                # Weight the batch loss by the batch size so the epoch value
                # is a per-sample mean even when the last batch is smaller.
                running_loss += loss.detach() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            n_samples = data_loaders.size_of_set(phase)
            # float() accepts both tensors and the initial Python numbers, so
            # an empty loader no longer fails on `.float()` / `.item()`.
            epoch_loss = float(running_loss) / n_samples
            epoch_acc = float(running_corrects) / n_samples

            losses[phase].append(epoch_loss)
            accuracies[phase].append(epoch_acc)

            # Log plain Python floats rather than tensors.
            log_metric(f'{phase}_loss', epoch_loss, epoch)
            log_metric(f'{phase}_accuracy', epoch_acc, epoch)

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                        epoch_loss,
                                                        epoch_acc))

            # ReduceLROnPlateau expects the monitored metric once per epoch;
            # the validation loss is the quantity being minimised.
            if not is_training and scheduler is not None:
                scheduler.step(epoch_loss)

        if (epoch + 1) % step_model_logging == 0:
            log_model(model,
                      f'model_{epoch + 1}',
                      data_loaders.get_dataloader('test'),
                      data_loaders.get_dataset('test'))

    return model, losses, accuracies

In [None]:
# Per-phase history containers: one list of epoch losses and one list of epoch
# accuracies for each phase, each starting empty.
phases = ['train', 'validation']
losses = {phase: [] for phase in phases}
accuracies = {phase: [] for phase in phases}

In [None]:
import mlflow.pytorch
from mlflow_extend import mlflow as mlex
from datetime import date
from time import time

# today = str(date.today())

mlflow.pytorch.autolog()

mlflow.set_tracking_uri('http://localhost:5000')
# The experiment is named after the selected backbone.
new_experiment = pretrained_model
mlflow.set_experiment(new_experiment)

n_epoches = 60

# Using the run as a context manager guarantees it is ended even when training
# raises or is interrupted, so no run is left dangling in the tracking server.
with mlflow.start_run(run_name=f'{pretrained_model}_test_2'):
    mlflow.log_param('Number of classes', waste_dataloader.n_classes())
    mlflow.log_param('Batch size', batch_size)
    mlflow.log_param('Dataset', './images/')
    mlflow.log_param('Epoch', n_epoches)

    start = time()
    # The variable name is kept for compatibility, but it holds the eager model
    # returned by train_model. The earlier `torch.jit.script(model)` result was
    # overwritten by this assignment without ever being used, so that call
    # has been dropped.
    scripted_model, losses, accuracies = train_model(model=model,
                                                     data_loaders=waste_dataloader,
                                                     loss_function=loss_func,
                                                     optimizer=optimizer,
                                                     scheduler=scheduler,
                                                     num_epochs=n_epoches,
                                                     losses=losses,
                                                     accuracies=accuracies)

    # Wall-clock training time in minutes.
    mlflow.log_param('Time', (time() - start) / 60)

In [None]:
# End the currently active MLflow run, if any (useful when the training cell
# above stopped before reaching its own end_run call).
mlflow.end_run()

In [None]:
# Save the current model's weights (its state_dict) under ./checkpoints/.
# NOTE(review): the file name says "resnet" — confirm it matches the backbone
# actually selected through `pretrained_model`.
torch.save(model.state_dict(), './checkpoints/waste_weights_resnet_23112023_data2.h5')

In [None]:
import json

# Persist the loss and accuracy histories as JSON files alongside the checkpoint.
history_files = (
    ('./checkpoints/waste_weights_resnet_23112023_data2_l.json', losses),
    ('./checkpoints/waste_weights_resnet_23112023_data2_a.json', accuracies),
)
for history_path, history in history_files:
    with open(history_path, 'w', encoding='utf8') as f:
        json.dump(history, f)

In [None]:
# Rebuild a ResNet classifier and restore previously saved weights.
# NOTE(review): the `+ 2` suggests the checkpoint was trained with two more
# classes than the current dataset exposes — confirm.
model = ResNetBaseClassifier(waste_dataloader.n_classes() + 2)
model.to('cuda')
model.load_state_dict(torch.load('./checkpoints/waste_weights_resnet_22112023_9.h5'))

# The inference cells below call `model_trained`; bind it here so they work
# after a fresh "Restart & Run All" instead of relying on leftover kernel state.
model_trained = model
model.eval()

In [None]:
# Display the object returned for the 'test' split's dataloader.
waste_dataloader.get_dataloader('test')

In [None]:
import torch.nn.functional as F
import numpy as np
from PIL import Image
from tqdm import tqdm

device = 'cuda'

# Work on a copy of the class list: entries are removed from `labels` in a
# later cell. (Presumably this keeps the loader's own list intact — verify.)
labels = waste_dataloader.get_classes().copy()

test_paths = datasets['test'].img_paths
test_labels = datasets['test'].img_labels
pred_labels = []

# Pure inference: disable autograd so no graph is built per image.
with torch.no_grad():
    for path in tqdm(test_paths):
        # The context manager releases the file handle once the image has been
        # converted by the transform.
        with Image.open(path) as img:
            image_tensor = data_transforms['test'](img).to(device)

        # Add the batch dimension expected by the model (batch of one).
        validation_batch = image_tensor.unsqueeze(0)
        pred_logits_tensor = model_trained(validation_batch)
        pred_probs = F.softmax(pred_logits_tensor, dim=1).cpu().numpy()
        pred_label = np.argmax(pred_probs[0])
        pred_labels.append(pred_label)

In [None]:
# Number of image paths in the 'test' dataset.
len(test_paths)

In [None]:
# Drop two classes from the label list.
# Filtering in place (rather than list.remove) keeps this cell safe to re-run:
# a second run of list.remove raises ValueError because the entries are
# already gone.
excluded_labels = {'37.Nhiệt kế thủy ngân', '38.Sơn, Vecni, Dung môi'}
labels[:] = [label for label in labels if label not in excluded_labels]

In [None]:
from sklearn.metrics import classification_report

# NOTE(review): `y_true` and `y_pred` are not assigned in any cell visible
# here; they must be defined before this cell runs (e.g. from the two lists
# returned by `test(...)`) — confirm the intended source.
print(classification_report(y_true, y_pred))

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# labels.remove('37.Nhiệt kế thủy ngân')
# labels.remove('38.Sơn, Vecni, Dung môi')

# Keep the matrix under its own name: rebinding `confusion_matrix` would shadow
# the sklearn function, which the `conf_matrix()` helper still calls.
cm = confusion_matrix(y_true, y_pred)

cm_display = ConfusionMatrixDisplay(confusion_matrix=cm,
                                    display_labels=list(range(0, waste_dataloader.n_classes())))

cm_display.plot()
plt.show()

In [None]:
# Display the entry at index 9 of the label list.
labels[9]

In [None]:
# Display the class list reported by the data loader.
waste_dataloader.get_classes()

In [None]:
import matplotlib.pyplot as plt
# Predict a handful of hand-picked images and show each with its predicted
# class name as the title.
labels = waste_dataloader.get_classes()
validation_img_paths = [
            "./images/0650.png",
            "./images/0697.png",
            "./images/0704.png",
            "./images/0655.png",
            "./images/0675.png",
            "./images/0681.png",
]
from PIL import Image
img_list = [Image.open(img_path) for img_path in validation_img_paths]

# Pure inference: no gradients are needed.
with torch.no_grad():
    validation_batch = torch.stack([data_transforms['validation'](img).to(device)
                                    for img in img_list])
    pred_logits_tensor = model_trained(validation_batch)
    pred_probs = F.softmax(pred_logits_tensor, dim=1).cpu().numpy()

# squeeze=False always returns a 2-D array of axes, so indexing also works
# when only one image is listed (otherwise `axs` would be a bare Axes object).
fig, axs = plt.subplots(1, len(img_list), figsize=(20, 5), squeeze=False)
for i, img in enumerate(img_list):
    ax = axs[0, i]
    ax.axis('off')
    ax.set_title(labels[np.argmax(pred_probs[i])])
    ax.imshow(img)