In [1]:
import wandb
from wandb.fastai import WandbCallback

In [2]:
wandb.init(project='droughtwatch', name='first-run')

W&B Run: https://app.wandb.ai/akashpalrecha/droughtwatch/runs/tk6oegbu

In [3]:
from fastai.vision import *
import tifffile

In [4]:
import torchvision
import torch.nn as nn
from functools import partial

resnet_models = {18: torchvision.models.resnet18,
                 34: torchvision.models.resnet34,
                 50: torchvision.models.resnet18,
                 101: torchvision.models.resnet101,
                 152: torchvision.models.resnet152}

class Resnet_multichannel(nn.Module):
    def __init__(self, pretrained=True, encoder_depth=34, num_in_channels=4):
        super().__init__()
        
        if encoder_depth not in [18, 34, 50, 101, 152]:
            raise ValueError(f"Encoder depth {encoder_depth} specified does not match any existing Resnet models")
            
        model = resnet_models[encoder_depth](pretrained)
        
        ##For reference: layers to use (in order):
        # conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc
        
        # This is the most important line of code here. This increases the number of in channels for our network
        self.conv1 = self.increase_channels(model.conv1, num_in_channels)
        
        self.bn1 = model.bn1
        self.relu = model.relu
        self.maxpool = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.avgpool = model.avgpool
        self.fc = model.fc
        
    def forward(self, x):
        x=x.float()
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x
        
    def increase_channels(self, m, num_channels=None, copy_weights=0):


        """
        takes as input a Conv2d layer and returns the a Conv2d layer with `num_channels` input channels
        and all the previous weights copied into the new layer.
        """
        # number of input channels the new module should have
        new_in_channels = num_channels if num_channels is not None else m.in_channels + 1
        
        bias = False if m.bias is None else True
        
        # Creating new Conv2d layer
        new_m = nn.Conv2d(in_channels=new_in_channels, 
                          out_channels=m.out_channels, 
                          kernel_size=m.kernel_size, 
                          stride=m.stride, 
                          padding=m.padding,
                          bias=bias)
        
        # Copying the weights from the old to the new layer
        new_m.weight[:, :m.in_channels, :, :] = m.weight.clone()
        
        #Copying the weights of the `copy_weights` channel of the old layer to the extra channels of the new layer
        for i in range(new_in_channels - m.in_channels):
            channel = m.in_channels + i
            new_m.weight[:, channel:channel+1, :, :] = m.weight[:, copy_weights:copy_weights+1, : :].clone()
        new_m.weight = nn.Parameter(new_m.weight)

        return new_m
    
def get_arch(encoder_depth, num_in_channels):
    """
    Returns just an architecture which can then be called in the usual way.
    For example:
    resnet34_4_channel = get_arch(34, 4)
    model = resnet34_4_channel(True)
    """
    return partial(Resnet_multichannel, encoder_depth=encoder_depth, num_in_channels=num_in_channels)

In [5]:
class ImageMultiList(ImageList):
    def open(self, fn):
        "Open image in `fn`, subclass and overwrite for custom behavior."
        img=tifffile.imread(str(fn))
        img=Image(img)
        return img 

In [6]:
from pathlib import Path
cur=Path.cwd()
Images=cur/'11_band_data'
class MultiDataBunch(ImageDataBunch):
    @classmethod
    def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid',
                    valid_pct=None, seed:int=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':
        "Create from imagenet style dataset in `path` with `train`,`valid`,`test` subfolders (or provide `valid_pct`)."
        path=Path(path)
        il = ImageMultiList.from_folder(path)
        if valid_pct is None: src = il.split_by_folder(train=train, valid=valid)
        else: src = il.split_by_rand_pct(valid_pct, seed)
        src = src.label_from_folder(classes=classes)
        return cls.create_from_ll(src, **kwargs)

In [7]:
data=MultiDataBunch.from_folder(Images, bs=256)

In [12]:
resnet34_11_channel=get_arch(101,11)
arch=resnet34_11_channel(True)
model=Learner(data,arch,metrics=[accuracy], callback_fns=WandbCallback)
#model=cnn_learner(data,models.resnet34,metrics=[accuracy])

In [13]:
%%wandb
# model.unfreeze()
model.fit_one_cycle(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.828528,0.88547,0.666172,01:23
1,0.736149,0.714122,0.729449,01:23
2,0.585139,0.622434,0.757283,01:23


Better model found at epoch 0 with valid_loss value: 0.8854697942733765.
Better model found at epoch 1 with valid_loss value: 0.7141223549842834.
Better model found at epoch 2 with valid_loss value: 0.6224339008331299.
Loaded best saved model from /home/ubuntu/droughtwatch/wandb/run-20190903_172806-tk6oegbu/bestmodel.pth


In [14]:
model.freeze_to(-3)

In [15]:
model.fit_one_cycle(5, slice(1e-2))

epoch,train_loss,valid_loss,accuracy,time
0,0.866361,10.049053,0.582761,01:23
1,0.857722,0.973928,0.630636,01:23
2,0.757022,1.085943,0.639822,01:23
3,0.657245,0.68204,0.740119,01:23
4,0.561885,0.652438,0.748655,01:23


Better model found at epoch 0 with valid_loss value: 10.049053192138672.
Better model found at epoch 1 with valid_loss value: 0.9739278554916382.
Better model found at epoch 3 with valid_loss value: 0.6820403337478638.
Better model found at epoch 4 with valid_loss value: 0.6524384617805481.
Loaded best saved model from /home/ubuntu/droughtwatch/wandb/run-20190903_172806-tk6oegbu/bestmodel.pth


In [16]:
model.unfreeze()

In [18]:
model.fit_one_cycle(15, slice(3e-3))

epoch,train_loss,valid_loss,accuracy,time
0,0.527574,0.656812,0.748933,01:23
1,0.537658,0.658219,0.751067,01:23
2,0.534838,0.721058,0.746335,01:23
3,0.528771,0.702374,0.739562,01:23
4,0.508074,0.692996,0.745686,01:22
5,0.484402,0.717736,0.742067,01:23
6,0.444775,0.759323,0.740583,01:23
7,0.401112,0.789606,0.740397,01:23
8,0.361573,0.793437,0.753387,01:22
9,0.327547,0.85212,0.750789,01:22


Better model found at epoch 0 with valid_loss value: 0.6568122506141663.
Loaded best saved model from /home/ubuntu/droughtwatch/wandb/run-20190903_172806-tk6oegbu/bestmodel.pth


Okay, so a resnet101 overfits. Clearly. Not doing that anymore. We're done with this one

<hr>

In [22]:
wandb.init(project='droughtwatch', name='resnet50-run')

W&B Run: https://app.wandb.ai/akashpalrecha/droughtwatch/runs/3to0uv3v

In [23]:
resnet50_11_channel=get_arch(50,11)
arch=resnet50_11_channel(True)
model=Learner(data,arch,metrics=[accuracy], callback_fns=WandbCallback)
#for some reason it used resnet18

In [24]:
model.fit_one_cycle(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.819278,0.929013,0.652533,00:26
1,0.681307,0.703111,0.73214,00:26
2,0.540884,0.606663,0.763964,00:26


Better model found at epoch 0 with valid_loss value: 0.9290125966072083.
Better model found at epoch 1 with valid_loss value: 0.7031105756759644.
Better model found at epoch 2 with valid_loss value: 0.6066631078720093.
Loaded best saved model from /home/ubuntu/droughtwatch/wandb/run-20190903_181428-3to0uv3v/bestmodel.pth


In [25]:
model.freeze_to(-3)

In [26]:
model.fit_one_cycle(5, slice(1e-2))

epoch,train_loss,valid_loss,accuracy,time
0,0.762917,105.537399,0.64576,00:27
1,0.91797,1.280535,0.401466,00:27
2,0.78034,0.808271,0.699295,00:26
3,0.657943,0.694838,0.73715,00:26
4,0.575928,0.645484,0.752459,00:26


Better model found at epoch 0 with valid_loss value: 105.53739929199219.
Better model found at epoch 1 with valid_loss value: 1.280535101890564.
Better model found at epoch 2 with valid_loss value: 0.8082706332206726.
Better model found at epoch 3 with valid_loss value: 0.6948381066322327.
Better model found at epoch 4 with valid_loss value: 0.6454839110374451.
Loaded best saved model from /home/ubuntu/droughtwatch/wandb/run-20190903_181428-3to0uv3v/bestmodel.pth


In [27]:
model.unfreeze()

In [None]:
model.fit_one_cycle(15, slice(3e-3))

epoch,train_loss,valid_loss,accuracy,time
0,0.544462,0.657609,0.754778,00:26
1,0.540697,0.65802,0.750881,00:26
2,0.551468,0.736394,0.71915,00:26
3,0.549219,0.73904,0.717573,00:26
4,0.523843,0.681697,0.74884,00:26
5,0.485416,0.777854,0.69781,00:26


Better model found at epoch 0 with valid_loss value: 0.6576094031333923.
