In [None]:
import pandas as pd
import glob

In [None]:
# Maps each dataset directory to the list of file paths it directly contains;
# populated by the directory-walk cell that follows.
path_dict = dict()

In [None]:
import os
from os import walk

# Record every directory under the dataset root that directly contains files,
# mapped to the full paths of those files, e.g.
#   '<root>/train_another/damage' -> ['<root>/train_another/damage/<img>', ...]
# Filenames are sorted because os.walk lists them in filesystem order, which
# would otherwise make the sample order (and therefore training) non-reproducible.
for (dirpath, dir_names, filenames) in walk('../input/satellite-images-of-hurricane-damage'):
    if filenames:
        path_dict[dirpath] = [os.path.join(dirpath, name) for name in sorted(filenames)]

In [None]:
import numpy as np 
import pandas as pd 
import os
import cv2
from tqdm import tqdm 
import glob
import sys

import torch 
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning import loggers as pl_loggers

! pip install timm
import timm

In [None]:
class Hurricane_Data(Dataset):
    """Binary image dataset for the hurricane-damage satellite images.

    Each item is ``(image, label)``:
      * ``image`` -- the RGB image after ImageNet mean/std normalization,
        converted to a channels-first torch tensor by ``ToTensorV2``.
      * ``label`` -- a 1-element int64 tensor: ``0`` when the file's parent
        directory is ``no_damage``, ``1`` otherwise.

    Args:
        path_array: list of image file paths.

    Raises (from ``__getitem__``):
        FileNotFoundError: if OpenCV cannot read the image at that index.
    """

    def __init__(self, path_array: list):
        self.path_array = path_array
        # ImageNet channel statistics, the usual normalization for
        # ImageNet-pretrained backbones.
        self.compose = A.Compose([
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.path_array)

    @staticmethod
    def _label_from_path(filename):
        """Return 0 if ``filename`` sits directly in a ``no_damage`` directory, else 1.

        Only the parent directory name is inspected, so a 'no_damage' substring
        elsewhere in the path (e.g. in the dataset root) cannot flip the label.
        """
        parent = os.path.basename(os.path.dirname(os.path.normpath(filename)))
        return 0 if parent == 'no_damage' else 1

    def __getitem__(self, index):
        filename = self.path_array[index]
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {filename}")
        # OpenCV loads BGR; the normalization statistics are in RGB order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # ToTensorV2 already yields a torch.Tensor, so it is returned as-is
        # (re-wrapping it with torch.tensor() copies it and warns).
        image = self.compose(image=image)['image']
        label = torch.tensor([self._label_from_path(filename)])
        return image, label

In [None]:
DATA_ROOT = '../input/satellite-images-of-hurricane-damage'


def get_split_paths(split: str) -> list:
    """Return all image paths of one dataset split.

    The result lists the split's ``damage`` files followed by its
    ``no_damage`` files, as collected in ``path_dict`` by the directory walk.

    Args:
        split: sub-directory of DATA_ROOT, e.g. ``'train_another'``.

    Raises:
        KeyError: if either class directory of the split was not found.
    """
    split_dir = f'{DATA_ROOT}/{split}'
    return path_dict[f'{split_dir}/damage'] + path_dict[f'{split_dir}/no_damage']


train_paths = get_split_paths('train_another')
valid_paths = get_split_paths('validation_another')
test_paths = get_split_paths('test_another')


In [71]:
class LitHurricane(pl.LightningModule):
    """Binary damage / no-damage classifier built on a pretrained timm ResNet-34.

    The backbone emits a single logit per image and is trained with
    ``BCEWithLogitsLoss``. A logit above 0 (sigmoid probability above 0.5) is
    predicted as class 1, matching the dataset's label 1 for non-``no_damage``
    images.

    Args:
        train_paths: image paths of the training split.
        valid_paths: image paths of the validation split.
        test_paths: image paths of the test split. They are only wrapped in a
            dataset here; no test dataloader or test step is defined.
        lr: Adam learning rate. Defaults to 1e-4, the rate the optimizer has
            always used.
        batch_size: batch size of the training and validation loaders.
    """

    def __init__(self, train_paths: list, valid_paths: list, test_paths: list,
                 lr: float = 1e-4, batch_size: int = 32):
        super().__init__()

        self.train_dataset = Hurricane_Data(train_paths)
        self.valid_dataset = Hurricane_Data(valid_paths)
        self.test_dataset = Hurricane_Data(test_paths)

        self.model = timm.create_model('resnet34',
                                       num_classes=1,
                                       pretrained=True)

        self.train_loss = nn.BCEWithLogitsLoss()
        self.valid_loss = nn.BCEWithLogitsLoss()

        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

        self.lr = lr
        self.batch_size = batch_size

    def forward(self, batch):
        return self.model(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=3e-6)

    def train_dataloader(self):
        # Shuffle is required: the notebook builds the training list as all
        # "damage" files followed by all "no_damage" files, so unshuffled
        # batches would be almost entirely single-class.
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.valid_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=os.cpu_count())

    @staticmethod
    def _predict(logits):
        """Hard 0/1 predictions from raw logits (logit 0 == probability 0.5)."""
        return (logits > 0).int()

    def training_step(self, batch, batch_idx):
        image, labels = batch
        logits = self(image)
        loss = self.train_loss(logits, labels.type_as(logits))
        preds = self._predict(logits)
        accuracy = self.train_acc(preds, labels)

        self.log("train_loss_batch", loss, prog_bar=True)
        self.log("train_acc_batch", accuracy, prog_bar=True)
        # Epoch-mean training loss, aggregated by Lightning.
        self.log("train_loss", loss, on_step=False, on_epoch=True)

        return {
            'loss': loss,
            'y_pred': preds,
            'y_true': labels,
        }

    def training_epoch_end(self, outputs):
        accuracy = self.train_acc.compute()
        # compute() does not clear the metric's state; reset so the next
        # epoch's accuracy covers only that epoch.
        self.train_acc.reset()
        self.log("train_acc_end", accuracy, prog_bar=True)
        print(f"Train accuracy for epoch {accuracy}")

    def validation_step(self, batch, batch_idx):
        image, labels = batch
        logits = self(image)
        loss = self.valid_loss(logits, labels.type_as(logits))
        preds = self._predict(logits)
        accuracy = self.valid_acc(preds, labels)

        self.log("valid_loss_batch", loss, prog_bar=True)
        self.log("valid_acc_batch", accuracy, prog_bar=True)

        return {
            'loss': loss,
            'y_pred': preds,
            'y_true': labels,
        }

    def validation_epoch_end(self, outputs):
        accuracy = self.valid_acc.compute()
        # Reset so each validation run (including Lightning's sanity check)
        # is scored independently.
        self.valid_acc.reset()
        self.log("valid_acc_end", accuracy, prog_bar=True)
        print(f"Valid accuracy for epoch {accuracy}")
        # Mean loss over every validation batch.
        val_loss = torch.stack([out['loss'] for out in outputs]).mean() if outputs else None
        return {
            'val_loss': val_loss,
            'val_acc_end': accuracy,
        }

In [None]:
from pytorch_lightning.callbacks.progress import ProgressBar


class LitProgressBar(ProgressBar):
    """tqdm progress bar with per-epoch summaries.

    It keeps each finished training bar on screen and labels the validation
    bar 'Valid'. At the end of every training and validation epoch it writes
    a one-line loss/accuracy summary.
    """

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        # Leave the finished epoch's bar visible instead of clearing it.
        bar.leave = True
        return bar

    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description('Valid')
        return bar

    @staticmethod
    def _metric_str(metrics, *keys):
        """Format the first of ``keys`` present in ``metrics`` to 4 d.p., or 'n/a' if none is logged."""
        for key in keys:
            if key in metrics:
                return f"{float(metrics[key]):.4f}"
        return "n/a"

    @staticmethod
    def _write(message):
        """Print ``message`` above any active tqdm bars."""
        # tqdm.write is a classmethod, so it works even before the main
        # training bar exists (e.g. during the validation sanity check).
        from tqdm import tqdm
        tqdm.write(message)

    # Lightning invokes callbacks through on_*_epoch_end(trainer, pl_module);
    # training_epoch_end / validation_epoch_end are LightningModule hooks.
    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        # *args/**kwargs absorb the extra ``outputs`` argument that some
        # Lightning 1.x releases pass to this hook.
        # NOTE(review): super() is not chained because the base hook's signature
        # differs across Lightning 1.x releases; chain it if the installed
        # ProgressBar implements this hook.
        metrics = trainer.callback_metrics
        self._write(
            f"Epoch {trainer.current_epoch + 1} "
            f"training loss={self._metric_str(metrics, 'train_loss', 'train_loss_batch')} "
            f"Accuracy={self._metric_str(metrics, 'train_acc_end')}"
        )

    def on_validation_epoch_end(self, trainer, pl_module):
        super().on_validation_epoch_end(trainer, pl_module)
        metrics = trainer.callback_metrics
        self._write(
            f"Epoch {trainer.current_epoch + 1} "
            f"validation loss={self._metric_str(metrics, 'val_loss', 'valid_loss_batch')} "
            f"Valid Accuracy = {self._metric_str(metrics, 'valid_acc_end')}"
        )

In [72]:
import warnings
warnings.filterwarnings('ignore')

In [73]:
if __name__ == '__main__':
    # Fixed seed so that head initialization and data shuffling are repeatable.
    pl.seed_everything(42)

    lit = LitHurricane(train_paths, valid_paths, test_paths)
    trainer = pl.Trainer(
        max_epochs=5,
        gradient_clip_val=1,
        # Progress display is provided by the LitProgressBar callback
        # (it uses its own refresh rate).
        callbacks=[LitProgressBar()],
        # Use one GPU when present; fall back to CPU instead of failing.
        gpus=1 if torch.cuda.is_available() else 0,
    )
    trainer.fit(lit)