Semantic Segmentation 
============

In this exercise you are going to work on a computer vision task called semantic segmentation. In contrast to image classification, the goal is not to classify an entire image but each of its pixels separately. This implies that the output of the network is not a single prediction for the whole image but a segmentation map with the same height and width as the input image. Think about why you should use convolutional rather than fully-connected layers for this task!
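A quick parameter count gives one part of the answer. The sketch below compares a single fully-connected layer that maps a 240x240 RGB image to 23 scores per pixel against a single 3x3 convolution with the same input and output channels. The sizes come from this exercise; the one-layer setup is an illustrative assumption, not a real architecture.

```python
# Input: 240x240 RGB image, output: 23 class scores per pixel (sizes from this exercise)
H, W, C_in, C_out = 240, 240, 3, 23

# fully-connected: every output value is connected to every input value
fc_params = (H * W * C_in) * (H * W * C_out) + H * W * C_out  # weights + biases

# 3x3 convolution: one small kernel per (in, out) channel pair, shared across all pixel positions
conv_params = 3 * 3 * C_in * C_out + C_out                      # weights + biases

print(f"fully-connected: {fc_params:,}")
print(f"3x3 conv:        {conv_params:,}")
```

The fully-connected layer needs roughly 2.3e11 parameters, while the convolution needs 644. The convolution also does not depend on the input size, and it applies the same filter at every position.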

<img src="https://camo.githubusercontent.com/d10b897e15344334e449104a824aff6c29125dc2/687474703a2f2f63616c76696e2e696e662e65642e61632e756b2f77702d636f6e74656e742f75706c6f6164732f646174612f636f636f7374756666646174617365742f636f636f73747566662d6578616d706c65732e706e67">

## (Optional) Mount folder in Colab

Uncomment the following cell to mount your Google Drive if you are using the notebook in Google Colab:

In [2]:
# Use the following lines if you want to use Google Colab
# We presume you created a folder "i2dl" within your main drive folder, and put the exercise there.
# NOTE: terminate all other colab sessions that use GPU!
# NOTE 2: Make sure the correct exercise folder (e.g. exercise_10) is given.

# """
from google.colab import drive
import os

gdrive_path='/content/gdrive/MyDrive/i2dl/exercise_10'

# This will mount your google drive under 'MyDrive'
drive.mount('/content/gdrive', force_remount=True)
# In order to access the files in this notebook we have to navigate to the correct folder
os.chdir(gdrive_path)
# Check manually if all files are present
print(sorted(os.listdir()))
# """

Mounted at /content/gdrive
['.tmp', '1_segmentation_nn.ipynb', 'exercise_code', 'images', 'lightning_logs', 'log', 'logs', 'models']


### Set up PyTorch environment in colab
- (OPTIONAL) Enable GPU via Runtime --> Change runtime type --> GPU
- Uncomment the following cell if you are using the notebook in Google Colab:

In [None]:
# Optional: install correct libraries in google colab
!python -m pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
!python -m pip install tensorboard==2.9.0
!python -m pip install pytorch-lightning==1.6.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.11.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp38-cp38-linux_x86_64.whl (1637.0 MB)

In [None]:
import sys

# For google colab
!python -m pip install pytorch-lightning==1.6.0 > /dev/null

# For anaconda/regular python
# !{sys.executable} -m pip install pytorch-lightning==1.6.0 > /dev/null

# 1. Preparation

## Imports

In [None]:
!pip install -U torchtext==0.8.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchtext==0.8.0
  Downloading torchtext-0.8.0-cp38-cp38-manylinux1_x86_64.whl (7.0 MB)
Installing collected packages: torchtext
  Attempting uninstall: torchtext
    Found existing installation: torchtext 0.14.1
    Uninstalling torchtext-0.14.1:
      Successfully uninstalled torchtext-0.14.1
Successfully installed torchtext-0.8.0


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch

from exercise_code.data.segmentation_dataset import SegmentationData, label_img_to_rgb
from exercise_code.data.download_utils import download_dataset
from exercise_code.util import visualizer, save_model
from exercise_code.util.Util import checkSize, checkParams, test
from exercise_code.networks.segmentation_nn import SegmentationNN, DummySegmentationModel
from exercise_code.tests import test_seg_nn
#set up default cuda device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2


## Setup TensorBoard
In exercise 07 you've already learned how to use TensorBoard. Let's use it again to make the debugging of our network and training process more convenient! Throughout this notebook, feel free to add further logs or visualizations to your TensorBoard!

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs --port 6006

## Load and Visualize Data

#### MSRC-v2 Segmentation Dataset

The MSRC v2 dataset is an extension of the MSRC v1 dataset from Microsoft Research in Cambridge. It contains *591* images and *23* object classes with accurate pixel-wise labels.



The image IDs of each split are stored in the text files `train.txt`, `val.txt` and `test.txt`. The dataset class reads the image IDs from these files and fetches the corresponding input and target images from the image folder.
<img src='images/input_target.png'/>



As you can see in `exercise_code/data/segmentation_dataset.py`, each segmentation label has a corresponding RGB value stored in `SEG_LABELS_LIST`. The label `void` means `unlabeled` and is displayed as black (`"rgb_values": [0, 0, 0]`) in the target image. Each target image pixel is assigned a label based on its color, using `SEG_LABELS_LIST`.

```python
SEG_LABELS_LIST = [
    {"id": -1, "name": "void",       "rgb_values": [0,   0,    0]},
    {"id": 0,  "name": "building",   "rgb_values": [128, 0,    0]},
    {"id": 1,  "name": "grass",      "rgb_values": [0,   128,  0]},
    {"id": 2,  "name": "tree",       "rgb_values": [128, 128,  0]},
    {"id": 3,  "name": "cow",        "rgb_values": [0,   0,    128]},
    {"id": 4,  "name": "horse",      "rgb_values": [128, 0,    128]},
    {"id": 5,  "name": "sheep",      "rgb_values": [0,   128,  128]},
    ...]
```
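To make the color-to-label mapping concrete, here is a minimal sketch that turns an RGB target image into a label map. It uses only the first four entries of the list above, and `rgb_to_label` is a hypothetical helper; this is not necessarily how `SegmentationData` implements the mapping.

```python
import numpy as np

# partial label list copied from above (the full list has more entries)
SEG_LABELS_LIST = [
    {"id": -1, "name": "void",     "rgb_values": [0, 0, 0]},
    {"id": 0,  "name": "building", "rgb_values": [128, 0, 0]},
    {"id": 1,  "name": "grass",    "rgb_values": [0, 128, 0]},
    {"id": 2,  "name": "tree",     "rgb_values": [128, 128, 0]},
]

def rgb_to_label(rgb_img):
    """Hypothetical helper: (H, W, 3) uint8 color image -> (H, W) int64 label map."""
    labels = np.full(rgb_img.shape[:2], -1, dtype=np.int64)   # colors without a match stay void
    for entry in SEG_LABELS_LIST:
        # a pixel matches if all three channels equal the entry's RGB value
        match = np.all(rgb_img == np.array(entry["rgb_values"]), axis=-1)
        labels[match] = entry["id"]
    return labels

img = np.array([[[0, 128, 0], [128, 0, 0]],
                [[0, 0, 0],   [128, 128, 0]]], dtype=np.uint8)
print(rgb_to_label(img))
```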

<div class="alert alert-block alert-warning">
    <h3>Note: The label <code>void</code></h3>
    <p>Pixels with the label <code>void</code> should neither be considered in your loss nor in the accuracy of your segmentation. See implementation for details.</p>
</div>

In [None]:
download_url = 'https://i2dl.vc.in.tum.de/static/data/segmentation_data.zip'
i2dl_exercises_path = os.path.dirname(os.path.abspath(os.getcwd()))
data_root = os.path.join(i2dl_exercises_path, 'datasets','segmentation')


download_dataset(
    url=download_url,
    data_dir=data_root,
    dataset_zip_name='segmentation_data.zip',
    force_download=False,
)

train_data = SegmentationData(image_paths_file=f'{data_root}/segmentation_data/train.txt')
val_data = SegmentationData(image_paths_file=f'{data_root}/segmentation_data/val.txt')
test_data = SegmentationData(image_paths_file=f'{data_root}/segmentation_data/test.txt')

If you want to implement data augmentation methods, make yourself familiar with the segmentation dataset and how we implemented the `SegmentationData` class in `exercise_code/data/segmentation_dataset.py`. Furthermore, you can check the original label description in `datasets/segmentation/segmentation_data/info.html`.
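One point to keep in mind for segmentation: geometric augmentations must be applied identically to the input image and its label map, otherwise pixels and labels no longer line up. The sketch below shows this for a random horizontal flip on tensors shaped like the ones returned here, (C, H, W) for the image and (H, W) for the target. `joint_hflip` is a hypothetical helper, not part of the exercise code. Photometric changes such as color jitter would only be applied to the image.

```python
import random
import torch

def joint_hflip(img, target, p=0.5, rng=random):
    """Hypothetical augmentation: flip image (C, H, W) and label map (H, W) together."""
    if rng.random() < p:
        img = torch.flip(img, dims=[-1])        # flip along the width axis
        target = torch.flip(target, dims=[-1])  # the same flip keeps labels aligned with pixels
    return img, target

img = torch.arange(12.0).reshape(3, 2, 2)
tgt = torch.tensor([[0, 1], [2, 3]])
flipped_img, flipped_tgt = joint_hflip(img, tgt, p=1.0)
print(flipped_tgt)
```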

For now, let's look at a few samples of our training set:

In [None]:
print("Train size: %i" % len(train_data))
print("Validation size: %i" % len(val_data))
print("Img size: ", train_data[0][0].size())
print("Segmentation size: ", train_data[0][1].size())

num_example_imgs = 4
plt.figure(figsize=(10, 5 * num_example_imgs))
for i, (img, target) in enumerate(train_data[:num_example_imgs]):
    # img
    plt.subplot(num_example_imgs, 2, i * 2 + 1)
    plt.imshow(img.numpy().transpose(1,2,0))
    plt.axis('off')
    if i == 0:
        plt.title("Input image")
    
    # target
    
    plt.subplot(num_example_imgs, 2, i * 2 + 2)
    plt.imshow(label_img_to_rgb(target.numpy()))
    plt.axis('off')
    if i == 0:
        plt.title("Target image")
plt.show()

We can already see that the dataset is quite small in comparison to our previous datasets: for CIFAR10 we had tens of thousands of images, while here we only have 276 training images. In addition, the task is much more difficult than a "simple 10-class classification", as we have to assign a label to each pixel! What's more, the images are much bigger, as we are now considering images of size 240x240 instead of 32x32.

That means that you shouldn't expect our networks to perform very well, so don't be too disappointed.

# 2. Semantic Segmentation 

## Dummy Model

In `exercise_code/networks/segmentation_nn.py` we define a naive `DummySegmentationModel`, which ignores its input and always outputs the label scores of the target image it was constructed with (here, the first training image). Let's try it on a few images and visualize the outputs using the `visualizer` we provide. The `visualizer` takes in the model and a dataset, and visualizes the first four (Input, Target, Prediction) triplets.
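The idea behind such a constant model can be sketched in a few lines: convert a fixed label map into one-hot scores of shape (1, num_classes, H, W) and return them for every input. `ConstantSegModel` is a hypothetical stand-in; mapping `void` to class 0 is an assumption made only because `F.one_hot` cannot encode -1. This is not necessarily how `DummySegmentationModel` is written.

```python
import torch
import torch.nn.functional as F

class ConstantSegModel(torch.nn.Module):
    """Hypothetical stand-in: always returns one-hot scores of a fixed label map."""
    def __init__(self, target_image, num_classes=23):
        super().__init__()
        labels = target_image.clone().long()
        labels[labels == -1] = 0                      # one_hot cannot encode -1; map void to class 0
        onehot = F.one_hot(labels, num_classes)       # (H, W, C)
        self.register_buffer("prediction", onehot.permute(2, 0, 1).float().unsqueeze(0))

    def forward(self, x):
        # ignore the input and repeat the stored prediction for every image in the batch
        return self.prediction.expand(x.shape[0], -1, -1, -1)

target = torch.tensor([[0, 1], [-1, 22]])
model = ConstantSegModel(target)
out = model(torch.zeros(3, 3, 2, 2))
print(out.shape)
```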

In [None]:
dummy_model = DummySegmentationModel(target_image=train_data[0][1])

# Visualization function
visualizer(dummy_model, train_data)

You can use the visualizer function in your training scenario to print out your model predictions on a regular basis.

## Loss and Metrics
The loss function for the task of image segmentation is a pixel-wise cross entropy loss. This loss examines each pixel individually, comparing the class predictions (depth-wise pixel vector) to our one-hot encoded target vector. 
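A minimal sketch with random tensors shows the shapes involved: the network outputs logits of shape (N, C, H, W), and the target holds one class index per pixel with shape (N, H, W). The index form is equivalent to comparing against a one-hot vector. The sketch checks that PyTorch's built-in loss equals the mean over all pixels of the negative log-softmax at each pixel's target class. The sizes are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 23, 4, 5                      # hypothetical batch size, classes and image size
logits = torch.randn(N, C, H, W)              # raw network scores, one C-vector per pixel
targets = torch.randint(0, C, (N, H, W))      # one class index per pixel

# built-in pixel-wise cross entropy, averaged over all N*H*W pixels
builtin = F.cross_entropy(logits, targets)

# manual version: log-softmax over the class dimension, then pick the target class per pixel
log_probs = F.log_softmax(logits, dim=1)                        # (N, C, H, W)
picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)   # (N, H, W)
manual = -picked.mean()

print(builtin.item(), manual.item())
```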


In [None]:
from PIL import Image

# path relative to the exercise folder (the working directory)
Image.open('images/loss_img.png')


Image source: https://www.jeremyjordan.me/semantic-segmentation/

Up until now we only used the default loss function (`nn.CrossEntropyLoss`) in our solvers. However, in order to ignore the `unlabeled` pixels when computing the loss, we have to use a customized version of the loss for our segmentation solver. The `ignore_index` argument of the loss filters out the `unlabeled` pixels (label `-1`), so that the loss is computed only over the remaining pixels.
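As a small sanity check (random logits, 3 hypothetical classes on a 2x2 image), the sketch below shows that `ignore_index=-1` with `reduction='mean'` gives the same value as computing the per-pixel loss and averaging it over the labeled pixels only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(1, 3, 2, 2)               # 3 hypothetical classes on a 2x2 image
targets = torch.tensor([[[0, -1], [2, 1]]])     # -1 marks an unlabeled ("void") pixel

loss_ignore = nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')(logits, targets)

# same result by hand: per-pixel losses, then average over labeled pixels only
mask = targets != -1
safe_targets = targets.clamp(min=0)            # -1 is not a valid class index, replace it first
per_pixel = nn.CrossEntropyLoss(reduction='none')(logits, safe_targets)  # (1, 2, 2)
loss_masked = per_pixel[mask].mean()

print(loss_ignore.item(), loss_masked.item())
```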


In [None]:
loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')

for (inputs, targets) in train_data[0:4]:
    outputs = dummy_model(inputs.unsqueeze(0))          # add a batch dimension: (1, C, H, W)
    losses = loss_func(outputs, targets.unsqueeze(0))   # targets: (1, H, W)
    print(losses)

<div class="alert alert-warning">
    <h3>Note: Non-zero loss for the first sample</h3> 
    <p>The output of our dummy model is a one-hot encoded tensor. Since <b>nn.CrossEntropyLoss</b> applies a (log-)<b>softmax</b> to its input internally, the loss is:</p>
</div>

$$loss(x, class) = -\log \left( \frac{\exp(x[class])}{\sum_j \exp (x[j])} \right) = -x[class] + \log \left( \sum_j \exp(x[j]) \right)$$

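Hence, even though the dummy model predicts the correct class for every labeled pixel of the first image, the loss is not zero.

For example, for $x=[0, 0, 0, 1]$ and $class=3$, the loss is: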

$$loss(x,class) = -1 +\log(\exp(0)+\exp(0)+\exp(0)+\exp(1)) = 0.7437$$
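The number above can be double-checked in PyTorch by treating the one-hot vector as logits for a single sample:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0.0, 0.0, 0.0, 1.0]])   # one-hot "prediction" interpreted as logits
target = torch.tensor([3])                  # the correct class
loss = F.cross_entropy(x, target)
print(round(loss.item(), 4))  # 0.7437
```

Scaling the logits up (e.g. `10 * x`) would push the loss toward zero. A one-hot output is simply not a very confident score vector once it passes through the softmax.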

To obtain an evaluation accuracy, we can simply compute the average per pixel accuracy of our network for a given image. We will use the following function:

In [None]:
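def evaluate_model(model, dataloader):
    test_scores = []
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)       # predicted class per pixel
            targets_mask = targets >= 0            # ignore "void" pixels (label -1)
            # compare and mask on the same device, then move the result to the CPU
            test_scores.append(np.mean((preds == targets)[targets_mask].cpu().numpy()))

    return np.mean(test_scores)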

test_loader = torch.utils.data.DataLoader(test_data, batch_size=1,shuffle=False,num_workers=0)
# print(evaluate_model(dummy_model, test_loader))

You will see reasonably high numbers as your accuracy when you do the training later. The reason behind that is the fact that most output pixels are of a single class and the network can just overfit to common classes such as "grass".
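A toy example makes this concrete. On a hypothetical 4x4 label map that is mostly grass, a "model" that predicts grass everywhere already reaches almost 79% pixel accuracy with the same void masking as `evaluate_model`:

```python
import numpy as np

# hypothetical 4x4 label map: 1 = grass, 0 = building, 3 = cow, -1 = void (ignored)
target = np.array([[ 1,  1,  1,  1],
                   [ 1,  1,  1,  0],
                   [ 1,  1,  3,  3],
                   [-1, -1,  1,  1]])

# a "model" that predicts grass for every pixel
preds = np.ones_like(target)

mask = target >= 0                       # ignore void pixels, as in evaluate_model
accuracy = np.mean((preds == target)[mask])
print(accuracy)                          # 11 correct out of 14 labeled pixels
```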

## Step 1: Design your own model

<div class="alert alert-info">
    <h3>Task: Implement</h3>
    <p>Implement your network architecture in <code>exercise_code/networks/segmentation_nn.py</code>. In this task, you will use PyTorch to set up your model.
    </p>
</div>

Strided convolutions and pooling layers reduce the spatial resolution of the feature maps. To get back to the target image shape, you should probably include a single `nn.Upsample` layer, a combination of upsampling layers and convolutions, or even transposed convolutions near the end of your network.
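A minimal encoder-decoder sketch, assuming 240x240 inputs and 23 classes, shows how these pieces fit together. Two stride-2 convolutions shrink H and W by 4. An `nn.Upsample` and an `nn.ConvTranspose2d` each double them again, and a final 1x1 convolution produces the class scores. `TinyEncoderDecoder` and the hparams keys `num_classes` and `base_channels` are hypothetical names chosen for illustration. This is not a recommended architecture for the exercise.

```python
import torch
import torch.nn as nn

class TinyEncoderDecoder(nn.Module):
    """Hypothetical minimal example, not the exercise's SegmentationNN."""
    def __init__(self, hparams=None):
        super().__init__()
        hparams = hparams or {}
        num_classes = hparams.get("num_classes", 23)
        ch = hparams.get("base_channels", 16)
        # encoder: each stride-2 conv halves H and W (240 -> 120 -> 60)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, ch, 3, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(ch, 2 * ch, 3, stride=2, padding=1), nn.ReLU(inplace=True),
        )
        # decoder: fixed 2x upsampling, then learned 2x upsampling (60 -> 120 -> 240)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(2 * ch, ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ch, ch, kernel_size=2, stride=2), nn.ReLU(inplace=True),
            nn.Conv2d(ch, num_classes, 1),   # 1x1 conv: features -> class scores per pixel
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = TinyEncoderDecoder({"num_classes": 23, "base_channels": 16})
out = model(torch.randn(2, 3, 240, 240))
print(out.shape)  # torch.Size([2, 23, 240, 240])
```

This shape round trip only works because 240 is divisible by 4. For arbitrary input sizes, a final `F.interpolate(scores, size=x.shape[-2:])` is a more robust way to match the target shape.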

This file is mostly empty but contains the expected class name, and the methods that your model needs to implement (only `forward()` basically). 
The only rules your model design has to follow are:
* Inherit from `torch.nn.Module` or `pytorch_lightning.LightningModule`
* Perform the forward pass in `forward()`. Input dimension is (N, C, H, W) and output dimension is (N, num_classes, H, W)
* Have fewer than 5 million parameters
* Have a model size of less than 50 MB after saving
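The rules above can be checked quickly before training. The sketch below counts parameters, verifies the output shape on a dummy CPU input, and estimates the saved size by serializing the `state_dict` in memory. `check_constraints` is a hypothetical helper. The size estimate uses 1 MB = 1e6 bytes and only approximates what the exercise's own `save_model`/`checkSize` report.

```python
import io
import torch
import torch.nn as nn

def check_constraints(model, input_shape=(1, 3, 240, 240), num_classes=23,
                      max_params=5_000_000, max_mb=50.0):
    """Hypothetical helper: checks output shape, parameter count and approximate saved size."""
    n_params = sum(p.numel() for p in model.parameters())
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(*input_shape))          # assumes the model is on the CPU
    expected = (input_shape[0], num_classes, *input_shape[2:])
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)              # serialize in memory to estimate file size
    size_mb = buffer.getbuffer().nbytes / 1e6
    return {
        "shape_ok": tuple(out.shape) == expected,
        "n_params": n_params,
        "params_ok": n_params < max_params,
        "size_mb": size_mb,
        "size_ok": size_mb < max_mb,
    }

stand_in = nn.Conv2d(3, 23, kernel_size=1)   # trivial stand-in for SegmentationNN
report = check_constraints(stand_in)
print(report)
```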

Furthermore, you need to pass all your hyperparameters to the model in a single dict `hparams`.

<div class="alert alert-warning">
    <h3>Note: Transfer learning</h3>
    <p>In this exercise, we encourage you to do transfer learning as we learned in exercise 8, since this will boost your model performance and save training time. You can import pretrained models from torchvision in your model and use their feature extractor (e.g. <code>alexnet.features</code>) to get the image features. Feel free to choose a more advanced pretrained model such as ResNet or MobileNet for your architecture design.</p>
</div>
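A minimal sketch of this pattern: wrap a feature extractor, optionally freeze it, add a small trainable head, and upsample the coarse scores back to the input size. `FrozenBackboneSegNet` is a hypothetical class. To keep the sketch self-contained, it is exercised with a tiny stand-in backbone instead of a downloaded torchvision model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FrozenBackboneSegNet(nn.Module):
    """Hypothetical sketch: a (pretrained) feature extractor plus a small trainable head."""
    def __init__(self, backbone, backbone_channels, num_classes=23, freeze=True):
        super().__init__()
        self.backbone = backbone
        if freeze:
            for p in self.backbone.parameters():
                p.requires_grad = False           # keep the pretrained weights fixed
        self.head = nn.Sequential(
            nn.Conv2d(backbone_channels, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, num_classes, 1),
        )

    def forward(self, x):
        h, w = x.shape[-2:]
        feats = self.backbone(x)                  # low-resolution feature map
        scores = self.head(feats)                 # class scores at feature resolution
        # upsample the scores back to the input resolution
        return F.interpolate(scores, size=(h, w), mode="bilinear", align_corners=False)

# tiny stand-in backbone (stride 4: 240 -> 60) so the example runs without downloads
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, stride=4, padding=1), nn.ReLU())
model = FrozenBackboneSegNet(backbone, backbone_channels=8)
out = model(torch.randn(2, 3, 240, 240))
print(out.shape)
```

With the AlexNet features from the cell further below, this could be `FrozenBackboneSegNet(models.alexnet(pretrained=True).features, backbone_channels=256)`. For a 240x240 input those features are only 6x6, so the upsampled prediction will be coarse. When freezing, it is common to pass only trainable parameters to the optimizer, e.g. `filter(lambda p: p.requires_grad, model.parameters())`. For backbones with BatchNorm, such as ResNet, `requires_grad = False` does not stop the running statistics from updating in train mode.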

See [here](https://pytorch.org/vision/stable/models.html) for more information on torchvision's pretrained models.


In [None]:
hparams = {
    "batch_size" : 47,
    "epochs" : 100,
    
    # optimizer
    "learning_rate" : 3e-4,
    "weight_decay": 1e-5,
    'momentum': 0.9
    # TODO: if you have any model arguments/hparams, define them here and read them from this dict inside SegmentationNN class
}

In [None]:
# https://medium.com/the-owl/extracting-features-from-an-intermediate-layer-of-a-pretrained-model-in-pytorch-c00589bda32b
# https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb#scrollTo=CRIOwoQvBKPm
# https://stackoverflow.com/questions/52235520/how-to-use-pnasnet5-as-encoder-in-unet-in-pytorch
# import torchvision.models as models

# rn18 = models.resnet18(pretrained=True)
# children_counter = 0
# for n,c in rn18.named_children():
#     print("Children Counter: ",children_counter," Layer Name: ",n,)
#     children_counter+=1
# rn18._modules

In [None]:
# from exercise_code.networks.segmentation_nn import DummySegmentationModel, DoubleConv, InConv, Down, Up, OutConv, Unet
from exercise_code.networks.segmentation_nn import SegmentationNN

model = SegmentationNN()
test_seg_nn(model)

In [None]:
import torchvision.models as models
alexnet = models.alexnet(pretrained=True).features
alexnet

## Step 2: Train your own model

<div class="alert alert-info">
    <h3>Task: Implement</h3>
    <p> In addition to the network itself, you will also need to write the code for the model training. You can use PyTorch Lightning for that, or you can also write it yourself in standard PyTorch.
    </p>
</div>

--- Random search
https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html
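A bare-bones random search can be written without extra libraries. The sketch below samples configurations from a hypothetical search space, with learning rate and weight decay drawn log-uniformly, and keeps the one with the lowest score. `sample_hparams`, `random_search` and the ranges are illustrative assumptions. In practice, `evaluate` would train briefly and return the validation loss.

```python
import math
import random

def sample_hparams(rng):
    """Hypothetical search space; the ranges are illustrative assumptions."""
    return {
        "learning_rate": 10 ** rng.uniform(-5, -2),   # log-uniform between 1e-5 and 1e-2
        "weight_decay": 10 ** rng.uniform(-6, -3),
        "batch_size": rng.choice([8, 16, 32]),
    }

def random_search(evaluate, n_trials=10, seed=0):
    """Try n_trials random configs and return the one with the lowest score."""
    rng = random.Random(seed)
    best_cfg, best_score = None, math.inf
    for _ in range(n_trials):
        cfg = sample_hparams(rng)
        score = evaluate(cfg)            # e.g. validation loss after a short training run
        if score < best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score

# toy objective standing in for a real training run: prefers lr close to 10**-3.5
best_cfg, best_score = random_search(lambda c: abs(math.log10(c["learning_rate"]) + 3.5), n_trials=50)
print(best_cfg, best_score)
```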

-- Learning rate plateau (ReduceLROnPlateau)
https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
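`ReduceLROnPlateau` is meant to be stepped once per epoch with a single metric, typically the mean validation loss. The sketch below uses a dummy parameter and a validation loss that never improves, with `factor=0.5` and `patience=2` chosen for illustration. The learning rate is halved once more than `patience` epochs pass without improvement.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.nn.Parameter(torch.zeros(1))            # dummy parameter for the optimizer
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

lrs = []
for epoch, val_loss in enumerate([1.0, 1.0, 1.0, 1.0, 1.0]):   # validation loss never improves
    scheduler.step(val_loss)          # one call per epoch with that epoch's validation loss
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)
```

Calling `scheduler.step(loss)` once per validation batch instead would count batches as epochs and reduce the learning rate far too often.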

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import ReduceLROnPlateau
import math
import pytorch_lightning as pl
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from tqdm import tqdm


########################################################################
# TODO - Train Your Model                                              #
########################################################################

# dataset (note: `transform` is not passed to SegmentationData, so it is currently unused)
transform = transforms.Compose([transforms.ToTensor()])

train_data = SegmentationData(image_paths_file=f'{data_root}/segmentation_data/train.txt')
val_data = SegmentationData(image_paths_file=f'{data_root}/segmentation_data/val.txt')
test_data = SegmentationData(image_paths_file=f'{data_root}/segmentation_data/test.txt')

# data loaders: reshuffle the training set every epoch, keep the validation order fixed
train_loader = DataLoader(train_data,
                          batch_size=hparams['batch_size'],
                          shuffle=True,
                          num_workers=2)

val_loader = DataLoader(val_data,
                        batch_size=hparams['batch_size'],
                        shuffle=False,
                        num_workers=2)

# move the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)

# optimizer and loss function
optimizer = optim.Adam(model.parameters(),
                       lr=hparams['learning_rate'],
                       weight_decay=hparams['weight_decay'])

loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')

# get the number of batches (the last batch may be smaller than batch_size)
def get_num_batches(dataloader, data):
    return math.ceil(len(data) / dataloader.batch_size)

# TODO : Learning Scheduler : https://velog.io/@idj7183/Competition%EC%97%90%EC%84%9C-%EC%83%88%EB%A1%9C-%EB%B0%B0%EC%9A%B4-%EA%B2%83
def get_scheduler(optimizer, step_size, gamma):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# create progress bar
def create_tqdm_bar(iterable, desc):
    return tqdm(enumerate(iterable), total=len(iterable), ncols=150, desc=desc, leave=False)

# training
def train_model(model, train_loader, val_loader, loss_func, optimizer, writer, epochs=50):
    # scheduler = get_scheduler(optimizer=optimizer, step_size=50, gamma=0.7)
    # scheduler = CosineAnnealingWarmRestarts(optimizer,
    #                                         T_0=8,         # number of scheduler steps until the first restart
    #                                         T_mult=1,      # factor by which T_i grows after a restart
    #                                         eta_min=1e-4)  # minimum learning rate
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.95, min_lr=1e-5)

    for epoch in range(epochs):
        # train
        model.train()
        train_running_loss = []  # per-batch losses, for printing
        for train_iteration, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            preds = model(inputs)
            loss = loss_func(preds, targets)
            loss.backward()
            optimizer.step()

            train_running_loss.append(loss.item())
            # update the TensorBoard writer (train), one point per batch
            writer.add_scalar('train_loss', loss.item(), epoch * len(train_loader) + train_iteration)

        # scheduler.step()  # use this instead for step-based schedulers (StepLR, CosineAnnealingWarmRestarts)

        # validation
        model.eval()
        validation_running_loss = []  # per-batch losses, for printing
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                loss = loss_func(model(inputs), targets)
                validation_running_loss.append(loss.item())

        train_loss = np.mean(train_running_loss)
        val_loss = np.mean(validation_running_loss)

        # update the TensorBoard writer (val), one point per epoch
        writer.add_scalar('val_loss', val_loss, epoch)

        # ReduceLROnPlateau expects one metric value per epoch, not one per batch
        scheduler.step(val_loss)

        # print current learning rate and losses
        print('current learning rate : ', optimizer.param_groups[0]['lr'])
        print("EPOCH %04d / %04d | %12s : %.6f | %12s : %0.6f"
              % (epoch + 1, epochs, 'VALIDATION LOSS', val_loss, 'TRAIN LOSS', train_loss))


# TensorBoard SummaryWriter
base_dir = '/content/gdrive/MyDrive/i2dl/exercise_10'
log_dir = os.path.join(base_dir, 'log')
writer = SummaryWriter(log_dir=log_dir)

# train
epochs = hparams['epochs']
train_model(model, train_loader, val_loader, loss_func, optimizer, writer, epochs)

# flush the logged losses to disk and close the writer
writer.close()

#######################################################################
#                           END OF YOUR CODE                          #
#######################################################################

# https://medium.com/analytics-vidhya/creating-a-very-simple-u-net-model-with-pytorch-for-semantic-segmentation-of-satellite-images-223aa216e705
# https://medium.com/the-owl/extracting-features-from-an-intermediate-layer-of-a-pretrained-model-in-pytorch-c00589bda32b
# https://dacon.io/codeshare/4245

In [None]:
%tensorboard --logdir log

In [None]:
model = model.to(device)
test(evaluate_model(model, test_loader))

# 3. Visualization

In [None]:
visualizer(model, test_data)

## Save the Model for Submission

When you are satisfied with your training, save the model for [submission](https://i2dl.vc.in.tum.de/submission/). In order to be eligible for the bonus points, you have to achieve an accuracy of at least __64%__.

In [None]:
os.makedirs('models', exist_ok=True)
save_model(model, "segmentation_nn.model")
checkSize(path = "./models/segmentation_nn.model")

In [None]:
from exercise_code.util.submit import submit_exercise

submit_exercise('../output/exercise10')

# Submission Instructions

Congratulations! You've just built your first semantic segmentation model with PyTorch Lightning! To complete the exercise, submit your final model to our submission portal - you probably know the procedure by now.

1. Go on [our submission page](https://i2dl.vc.in.tum.de/submission/), register for an account and log in. We use your matriculation number and send an email with the login details to the associated mail account. When in doubt, log into TUMonline and check your mails there. You will get an ID which we need in the next step.
2. Log into [our submission page](https://i2dl.vc.in.tum.de/submission/) with your account details and upload the `zip` file. Once successfully uploaded, you should be able to see the submitted file selectable on the top.
3. Click on this file and run the submission script. You will get an email with your score as well as a message if you have surpassed the threshold.

# Submission Goals

- Goal: Implement and train a convolutional neural network for Semantic Segmentation.
- Passing Criteria: Reach **Accuracy >= 64%** on __our__ test dataset. The submission system will show you your score after you submit.
- Submission start: __January 19, 2023 - 13:00__
- Submission deadline: __January 25, 2023 - 15:59__
- You can make **$\infty$** submissions until the deadline. Your __best submission__ will be considered for the bonus.

# [Exercise Review](https://docs.google.com/forms/d/e/1FAIpQLScwZArz6ogLqBEj--ItB6unKcv0u9gWLj8bspeiATrDnFH9hA/viewform)

We are always interested in your opinion. Now that you have finished this exercise, we would like you to give us some feedback about the time required to finish the submission and/or work through the notebooks. Please take a moment to fill out our [review form](https://docs.google.com/forms/d/e/1FAIpQLScwZArz6ogLqBEj--ItB6unKcv0u9gWLj8bspeiATrDnFH9hA/viewform) for this exercise so that we can do better next time! :)

In [None]:

################
# def fit(model, train_data, train_loader, loss_func, optimizer, writer_train, epoch):
#   model.train()

#   # for epoch in range(epochs):
#   train_running_loss = 0.0
#   count = 0

#   # get the number of batches for the tqdm progress bar 
#   num_batches = get_num_batches(train_loader, train_data)

#   for i, train_batch in enumerate(train_loader):
#     count += 1

#     input, target = train_batch
#     input, target = input.to(device), target.to(device)

#     optimizer.zero_grad()

#     preds = model(input)
#     loss = loss_func(preds, target)
#     train_running_loss += loss.item()

#     loss.backward() # ? 
#     optimizer.step()

#     # update the tensorboard 
#     writer_train.add_scalar('train_loss', loss.item(), epoch * len(train_loader) + i)
  
#   return train_running_loss / count



# def eval(model, val_loader, loss_func, optimizer, writer_val, epoch):
#   model.eval()

#   # for epoch in range(epochs):
#   val_running_loss = 0.0
#   count = 0
  
#   with torch.no_grad():
#     for i, val_batch in enumerate(val_loader):
#       count += 1

#       input, target = val_batch 
#       input, target = input.to(device), target.to(device)

#       preds = model(input)
#       loss = loss_func(preds, target)
#       val_running_loss += loss.item()

#       # update the tensorboard 
#       writer_val.add_scalar('val_loss', loss.item(), epoch * len(val_loader) + i)

#   return val_running_loss / count 
  



# # Tensorboard, summarywriter 
# from torch.utils.tensorboard import SummaryWriter
# base_dir = '/content/gdrive/MyDrive/i2dl/exercise_10'
# log_dir = os.path.join(base_dir, 'log')
# writer_train = SummaryWriter(log_dir = os.path.join(log_dir, 'train'))
# writer_val = SummaryWriter(log_dir = os.path.join(log_dir, 'val'))

# # Import progressbar 
# from tqdm import tqdm

# ## train the model
# train_loss = [] 
# val_loss = []
# for epoch in range(epochs):
#   train_epoch_loss = fit(model, train_data, train_loader, loss_func, optimizer, writer_train, epoch)
#   val_epoch_loss = eval(model, val_loader, loss_func, optimizer, writer_val, epoch)
#   train_loss.append(train_epoch_loss)
#   val_loss.append(val_epoch_loss)
#   print(f'Epoch : {epoch + 1} / {epochs} :: Train Loss: {train_epoch_loss:.8f} / Validation Loss: {val_epoch_loss:.8f} ')


    


#
# trainer = pl.Trainer( max_epochs = 50,
#                       accelerator="gpu", devices=-1)


# train_dataloader = DataLoader(train_data, batch_size = hparams["batch_size"])
# val_dataloader = DataLoader(val_data, batch_size = hparams["batch_size"])

# trainer.fit(model, train_dataloader, val_dataloader)  