# Training a Segmentation Model from timm in PyTorch
> A simple example of training a binary segmentation model with a timm Xception backbone in PyTorch.

In [None]:
#| hide
# Notebook-only setup: enable IPython's autoreload extension in mode 2 so that
# edited project modules are re-imported before cells run, without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| default_exp xception_segmentation

In [None]:
#| export
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
#| export
from cv_tools.imports import *

In [None]:
#| export
class XceptionBinarySegmentation(nn.Module):
    """Binary segmentation network: timm Xception encoder plus a light conv head.

    The encoder is created with ``features_only=True``. Only its deepest
    feature map is decoded into a single-channel logit map, which is then
    bilinearly resized to the output resolution.

    Args:
        in_chans: Number of input image channels (default 1, single-channel).
        out_size: ``(height, width)`` of the returned logits. Defaults to
            ``(1152, 1632)`` for backward compatibility. Pass ``None`` to
            resize the logits back to the spatial size of each input batch,
            so the prediction always lines up with the input image.
        pretrained: Forwarded to ``timm.create_model``. The default ``False``
            builds the encoder from scratch.

    Returns (from ``forward``):
        Raw logits of shape ``(batch, 1, H, W)``. No sigmoid is applied.
    """

    def __init__(self, in_chans=1, out_size=(1152, 1632), pretrained=False):
        super().__init__()
        self.xception = timm.create_model(
            'xception',
            pretrained=pretrained,
            in_chans=in_chans,
            features_only=True,
        )
        # Ask the encoder for the width of its deepest feature map instead of
        # hard-coding 2048, so the head stays consistent with the backbone.
        enc_channels = self.xception.feature_info.channels()[-1]

        self.decoder = nn.Sequential(
            nn.Conv2d(enc_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1, kernel_size=1),
        )

        self.out_size = tuple(out_size) if out_size is not None else None
        # A fixed-size resampler is kept as `self.upsample` for backward
        # compatibility. When out_size is None, forward() resizes the logits
        # to the input's spatial size instead.
        self.upsample = (
            nn.Upsample(size=self.out_size, mode='bilinear', align_corners=True)
            if self.out_size is not None
            else None
        )

    def forward(self, x):
        """Return per-pixel logits for an image batch ``x`` of shape (B, C, H, W)."""
        input_size = x.shape[-2:]
        features = self.xception(x)
        # Only the deepest (last) encoder feature map is decoded.
        logits = self.decoder(features[-1])
        if self.upsample is not None:
            return self.upsample(logits)
        return nn.functional.interpolate(
            logits, size=input_size, mode='bilinear', align_corners=True
        )

In [None]:
# Training-time augmentation pipeline:
# - random horizontal/vertical flips, 90-degree rotations and transposes;
# - mild shift/scale jitter (rotation disabled) and perspective warps;
# - finally, conversion of the image to a PyTorch tensor.
# NOTE(review): no Resize/Normalize step is applied here. Confirm that pixel
# scaling and float conversion happen downstream (dataset or collate_fn)
# before images reach the model.
trn_transforms = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.01,
            scale_limit=0.04,
            rotate_limit=0,
            p=0.25,
        ),
        A.Perspective(p=0.25),
        ToTensorV2(),
    ],
)

In [None]:
# Validation pipeline: no augmentation, only conversion to a PyTorch tensor.
# NOTE(review): like the training pipeline, this applies no Resize/Normalize.
# Confirm that scaling is handled elsewhere.
val_transforms = A.Compose(transforms=[ToTensorV2()])

In [None]:
from segmentation_test.dataloader_creation import *

In [None]:
# Image and mask folders live in `data/` one level above the notebook's
# current working directory.
image_path = Path.cwd().parent / 'data' / 'images'
mask_path = Path.cwd().parent / 'data' / 'masks'
image_path

Path('/home/user/Schreibtisch/projects/git_data/segmentation_test/data/images')

In [None]:
# Build the training and validation loaders. The images are split at random
# with split_per=0.8 (80% train, 20% validation); for the captured run this
# gave 6320 / 1581 of 7901. Each split gets its own transform pipeline.
train_dl, val_dl = create_pytorch_dataloader(
    split_type='random',
    split_per=0.8,
    exts='.png',
    batch_size=4,
    image_path=image_path,
    mask_path=mask_path,
    trn_transforms=trn_transforms,
    val_transforms=val_transforms,
    collate_fn=repeat_collate_fn,
)

 Number of images found = 7901
 training dataset length = 6320 and validation dataset length=  1581


In [None]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells.
import nbdev

nbdev.nbdev_export('12_xception_segmentation.ipynb')