#### Importing Modules

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os 
import torch
from torch import nn
import torch.nn.functional as F
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

#import os
##for dirname, _, filenames in os.walk('/kaggle/input'):
  #  for filename in filenames:
   #     print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
!cp -r /kaggle/input/segment-anything-zip-file/segment-anything /kaggle/input/segment-anything

cp: cannot create directory '/kaggle/input/segment-anything': Read-only file system


In [None]:
# The metadata files below are JSON, so they are left disabled here
# (pd.read_csv is not the right reader for them).
#validation_meta = pd.read_csv(r'/kaggle/input/google-research-identify-contrails-reduce-global-warming/validation_metadata.json')
#train_metadata = pd.read_csv(r'/kaggle/input/google-research-identify-contrails-reduce-global-warming/train_metadata.json')

# Load the sample submission table shipped with the competition data.
sample_submission = pd.read_csv(
    os.path.join(
        '/kaggle/input/google-research-identify-contrails-reduce-global-warming',
        'sample_submission.csv',
    )
)

#### Configuration

In [None]:
# Single place for every tunable hyperparameter of the run.
config = dict(
    batch_size=8,
    learning_rate=0.001,
    n_epochs=10,
    # Add any other hyperparameters you want to track
)

# Expose the hyperparameters as plain module-level names for the cells below.
batch_size, learning_rate, n_epochs = (
    config[key] for key in ('batch_size', 'learning_rate', 'n_epochs')
)

#### Setting up lazy-loading and preprocessing functions

In [None]:
from torch.utils.data import Dataset, DataLoader

# (lower, upper) bounds passed to normalize_range when building the
# false-colour image in get_ash_color_images.
# Blue channel: bounds applied to band 14 itself (presumably brightness
# temperature in Kelvin -- TODO confirm units).
_T11_BOUNDS = (243, 303)
# Green channel: bounds applied to the difference band 14 - band 11.
_CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
# Red channel: bounds applied to the difference band 15 - band 14.
_TDIFF_BOUNDS = (-4, 2)

def normalize_range(data, bounds):
    """Linearly rescale ``data`` so that ``bounds`` maps onto [0, 1].

    ``bounds[0]`` is sent to 0 and ``bounds[1]`` to 1. Values outside the
    bounds land outside [0, 1]; no clipping is done here.
    """
    lower = bounds[0]
    span = bounds[1] - lower
    return (data - lower) / span

def getdata(obs_path):
    """Load the band arrays and, if present, the label mask of one observation.

    Args:
        obs_path: Directory holding the observation's ``.npy`` files.

    Returns:
        A ``(bands_data, labels)`` tuple. ``bands_data`` maps each file name
        containing ``"band"`` (truncated at its first ``.``) to the array
        loaded from it. ``labels`` is the array stored in
        ``human_pixel_masks.npy``, or ``None`` when that file does not exist.
    """
    # Every directory entry whose name mentions "band" is treated as band data.
    bands_data = {
        entry.split('.')[0]: np.load(os.path.join(obs_path, entry))
        for entry in os.listdir(obs_path)
        if "band" in entry
    }

    # The aggregated contrail markings serve as labels when they are available.
    mask_file = os.path.join(obs_path, 'human_pixel_masks.npy')
    labels = np.load(mask_file) if os.path.exists(mask_file) else None
    return bands_data, labels

def get_ash_color_images(bands_data, get_mask_frame_only=False, frame_index=4) -> np.ndarray:
    """Build a false-colour image from bands 11, 14 and 15.

    The three channels are normalised band combinations:
    red = band 15 - band 14, green = band 14 - band 11, blue = band 14.
    The stacked result is clipped to [0, 1].

    Args:
        bands_data: Mapping with at least the keys ``'band_11'``,
            ``'band_14'`` and ``'band_15'``. The arrays are presumably
            (height, width, time) -- TODO confirm against the dataset.
        get_mask_frame_only: When true, only the slice ``[:, :, frame_index]``
            of each band is used.
        frame_index: Position along the third axis of the frame to keep when
            ``get_mask_frame_only`` is true. Defaults to 4, the value that
            was previously hard-coded (presumably the labelled frame --
            verify against the dataset description).

    Returns:
        The channels stacked along axis 2. For 2-D per-band slices this is
        ``(d0, d1, 3)``; for 3-D bands it is ``(d0, d1, 3, d2)``.

    Raises:
        KeyError: If one of the three required bands is missing.
    """
    band11 = bands_data['band_11']
    band14 = bands_data['band_14']
    band15 = bands_data['band_15']

    if get_mask_frame_only:
        band11 = band11[:, :, frame_index]
        band14 = band14[:, :, frame_index]
        band15 = band15[:, :, frame_index]

    r = normalize_range(band15 - band14, _TDIFF_BOUNDS)
    g = normalize_range(band14 - band11, _CLOUD_TOP_TDIFF_BOUNDS)
    b = normalize_range(band14, _T11_BOUNDS)
    # Clip after stacking so that out-of-bounds readings saturate at 0 or 1.
    false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
    return false_color

def preprocess_func(bands_data):
    """Turn raw band data into a channel-first false-colour image.

    Builds the single-frame false-colour image and moves its last axis
    (the three colour channels) to the front via ``transpose(2, 0, 1)``.
    """
    return get_ash_color_images(
        bands_data, get_mask_frame_only=True
    ).transpose(2, 0, 1)


<font size="4" face="verdana">

| **Band Number** | **Description** | **Additional Details** |
| --- | --- | --- |
| **Band 8** | "Upper-Level Tropospheric Water Vapor" Band | This band is helpful in tracking upper-level atmospheric moisture and jet stream winds, aiding in the forecasting of severe weather events and heavy rainfall. |
| **Band 9** | "Mid-Level Tropospheric Water Vapor" Band | This band captures images of mid-tropospheric moisture and atmospheric motion, valuable for identifying features such as tropical cyclones and thunderstorms. |
| **Band 10** | "Lower-level Water Vapor" Band | The lower-level water vapor band assists in identifying low-level moisture content, aiding in the prediction of fog, frost, and low clouds. |
| **Band 11** | "Cloud-Top Phase" Band | This band helps to determine cloud phases (water, mixed, ice) and heights, crucial for aviation safety and general weather prediction. |
| **Band 12** | "Ozone" Band | The ozone band detects ozone concentration in the atmosphere, offering insights into ozone layer health and aiding in the prediction of UV index and air quality. |
| **Band 13** | "Clean" IR Longwave Window Band | Primarily used for detection of clouds at all levels, sea surface temperature, and rainfall. |
| **Band 14** | IR Longwave Window Band | This band is used for surface and cloud top temperature estimates, identifying cloud types and cloud motion. |
| **Band 15** | "Dirty" Longwave Window Band | Used for estimation of lower-tropospheric water vapor, volcanic ash detection, and for improved rainfall estimation. |
| **Band 16** | "CO2" Longwave Infrared | This band aids in the estimation of cloud height and temperature, especially for high-level clouds. |

</font>

In [None]:
class ContrailsData(Dataset):
    """Lazily loads one observation directory per item.

    Each item is read from disk on access via ``getdata``; nothing is cached.
    """

    def __init__(self, data_paths, preprocess_func, train_validation):
        """Store the dataset description.

        Args:
            data_paths: Table with a ``'path'`` column; it is indexed with
                ``.iloc`` (a pandas DataFrame in this notebook).
            preprocess_func: Callable that turns the band dictionary returned
                by ``getdata`` into the model input array.
            train_validation: Split tag. It is stored but not used by this
                class.
        """
        self.data_paths = data_paths
        self.preprocess_func = preprocess_func
        self.train_validation = train_validation

    def __len__(self):
        """Return the number of observation directories."""
        return len(self.data_paths['path'])

    def __getitem__(self, index):
        """Load and preprocess the observation at position ``index``.

        Returns:
            ``(data, labels)``, both ``torch.float32`` tensors. ``labels`` is
            the label mask with its last axis moved to the front.

        Raises:
            FileNotFoundError: If the observation has no label mask, which
                previously surfaced as an ``AttributeError`` on ``None``.
        """
        obs_path = self.data_paths.iloc[index]['path']
        bands_data, labels = getdata(obs_path)
        if labels is None:
            raise FileNotFoundError(
                f'No label mask found for observation {obs_path!r}; '
                'ContrailsData only supports labelled observations.'
            )

        # Match the float32 model weights regardless of the dtype on disk.
        data = torch.as_tensor(self.preprocess_func(bands_data), dtype=torch.float32)

        labels = labels.transpose(2, 0, 1)
        labels = torch.from_numpy(labels)
        labels = labels.float()
        return data, labels


def get_data_paths(root_dir):
    """List the immediate sub-directories of ``root_dir``.

    Args:
        root_dir: Directory whose sub-directories each hold one observation.

    Returns:
        A DataFrame with a single ``'path'`` column containing the full path
        of every sub-directory, sorted so that row order (and therefore
        dataset indices) is the same on every run. ``os.listdir`` alone
        returns entries in arbitrary order. Plain files are skipped.
    """
    all_dirs = sorted(
        full_path
        for full_path in (os.path.join(root_dir, name) for name in os.listdir(root_dir))
        if os.path.isdir(full_path)
    )
    return pd.DataFrame(all_dirs, columns=['path'])

# Locations of the two labelled splits inside the competition data.
root_dir_train, root_dir_validation = (
    '/kaggle/input/google-research-identify-contrails-reduce-global-warming/' + split
    for split in ('train', 'validation')
)

# One row per observation directory, for each split.
train_data_paths, validation_data_paths = (
    get_data_paths(split_root)
    for split_root in (root_dir_train, root_dir_validation)
)

In [None]:
# Lazily loaded datasets for the two splits.
train_dataset = ContrailsData(
    data_paths=train_data_paths,
    preprocess_func=preprocess_func,
    train_validation='train',
)
validation_dataset = ContrailsData(
    data_paths=validation_data_paths,
    preprocess_func=preprocess_func,
    train_validation='validation',
)

# Only the training split is shuffled; validation keeps its fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False)

In [None]:
dataiter = iter(train_loader)
# Pull one batch so that `data` and `labels` exist for the shape check that
# follows; without this the notebook fails on a fresh kernel.
data, labels = next(dataiter)


In [None]:
# Report the shapes of `data` and `labels`.
print(f'shape of input: {data.shape}\nshape of labels: {labels.shape}')

#### Defining Models

*  TODO: more advanced models have not been tried yet.

In [None]:
class SimpleConv2D(nn.Module):
    """Three 3x3 convolutions followed by a 1x1 convolution to one channel.

    Every convolution keeps the spatial size (stride 1, padding 1 for the
    3x3 layers), so the output has the same height and width as the input.
    The output is raw logits: no sigmoid is applied.
    """

    def __init__(self):
        super().__init__()

        def same_conv(in_channels, out_channels):
            # 3x3 convolution that preserves height and width.
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.conv1 = same_conv(3, 64)
        self.conv2 = same_conv(64, 128)
        self.conv3 = same_conv(128, 256)
        self.conv_final = nn.Conv2d(256, 1, kernel_size=1, stride=1)

    def forward(self, x):
        """Return per-pixel logits for a 3-channel input."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.relu(layer(x))
        # No sigmoid here: the network emits logits.
        return self.conv_final(x)

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder with one convolution per level.

    Four encoder levels (each followed by 2x2 max pooling), a bottleneck, and
    four decoder levels that upsample with a stride-2 transposed convolution
    and fuse the matching encoder output by channel concatenation. The
    output is a single-channel map of raw logits.

    Because of the four poolings, the input height and width need to be
    divisible by 16 for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # 3x3 convolution that preserves height and width.
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        def up2x(in_channels, out_channels):
            # Transposed convolution that doubles height and width.
            return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # Encoder and bottleneck.
        self.conv1 = conv3x3(3, 64)
        self.conv2 = conv3x3(64, 128)
        self.conv3 = conv3x3(128, 256)
        self.conv4 = conv3x3(256, 512)
        self.conv5 = conv3x3(512, 1024)

        # Upsampling steps of the decoder.
        self.upconv5 = up2x(1024, 512)
        self.upconv4 = up2x(512, 256)
        self.upconv3 = up2x(256, 128)
        self.upconv2 = up2x(128, 64)

        # Fusion convolutions; each input is skip channels + upsampled channels.
        self.conv6 = conv3x3(1024, 512)
        self.conv7 = conv3x3(512, 256)
        self.conv8 = conv3x3(256, 128)
        self.conv9 = conv3x3(128, 64)

        self.conv_final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits with the same height and width as ``x``."""
        skips = []
        hidden = x
        for encode in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = F.relu(encode(hidden))
            skips.append(hidden)  # kept for the matching decoder level
            hidden = F.max_pool2d(hidden, kernel_size=(2))

        hidden = F.relu(self.conv5(hidden))

        decoder_levels = (
            (self.upconv5, self.conv6),
            (self.upconv4, self.conv7),
            (self.upconv3, self.conv8),
            (self.upconv2, self.conv9),
        )
        # Deepest skip first; concatenate along the channel axis (third from last).
        for (upsample, fuse), skip in zip(decoder_levels, reversed(skips)):
            hidden = torch.cat([skip, upsample(hidden)], dim=-3)
            hidden = F.relu(fuse(hidden))

        return self.conv_final(hidden)

#### Training

*  The batch size had to be decreased and the CUDA cache emptied because of memory issues.


In [None]:
import torch.nn.functional as F
# Define the device for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize your model
#model = SimpleConv2D().to(device)
model = UNet().to(device)

# The models emit raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Prepare your data loaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

# Training loop
for epoch in range(n_epochs):
    model.train()
    running_loss = 0
    batches = 0
    for images, labels in train_dataloader:
        # Cast on transfer so the inputs always match the float32 weights.
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so
        # that a smaller final batch is averaged correctly below.
        running_loss += loss.item() * images.size(0)
        batches += 1
        if batches % 100 == 0:
            print(f'batches ran: {batches}')

    epoch_loss = running_loss / len(train_dataloader.dataset)
    print('Train Loss: {:.4f}'.format(epoch_loss))

    # Release cached blocks once per epoch rather than on every batch:
    # empty_cache() does not shrink what live tensors occupy, and calling it
    # per step forces the allocator to re-request memory each iteration.
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Validation loop
    model.eval()
    with torch.no_grad():
        running_loss = 0
        for images, labels in validation_dataloader:
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(validation_dataloader.dataset)
        print('Validation Loss: {:.4f}'.format(epoch_loss))