In [1]:
!pip install utils

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.depht = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, X):
        return self.depht(X)

class myUNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
      super(myUNET, self).__init__()
      self.encode = nn.ModuleList()
      self.decode = nn.ModuleList()
      self.pool = nn.MaxPool2d(kernel_size=2, stride=2) #divide size per 2

      # Encode
      for feature in features:
          self.encode.append(DoubleConv(in_channels, feature))
          in_channels = feature
      
      # Decode
      for feature in reversed(features):
          self.decode.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
          self.decode.append(DoubleConv(feature*2, feature))

      self.bottleneck = DoubleConv(features[-1], features[-1]*2)
      self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, X):
        #print(X.shape)
        skip_connections = []

        for depht in self.encode:
            X = depht(X)
            skip_connections.append(X)
            X = self.pool(X)

        X = self.bottleneck(X)
        skip_connections = skip_connections[::-1]

        for i in range(0, len(self.decode), 2):
            X = self.decode[i](X)
            #skip_connection = skip_connections[-1 - i//2]
            skip_connection = skip_connections[i//2]

            if X.shape != skip_connection.shape:
                X = torch.nn.functional.interpolate(X, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, X), dim=1)
            print(concat_skip.shape)
            X = self.decode[i+1](concat_skip)

        return self.final_conv(X)

def test():
    x = torch.randn((3, 1, 160, 160))
    m = myUNET(in_channels=1, out_channels=1)
    pred = m(x)

    assert pred.shape == x.shape

test()

torch.Size([3, 1024, 20, 20])
torch.Size([3, 512, 40, 40])
torch.Size([3, 256, 80, 80])
torch.Size([3, 128, 160, 160])


In [3]:
!unzip Dataset_Faces_training.zip
!unzip Dataset_Faces_validation.zip

Archive:  Dataset_Faces_training.zip
replace Dataset_Faces_training/1.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: Archive:  Dataset_Faces_validation.zip
replace Dataset_Faces_validation/1000.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [4]:
!unzip Dataset_Faces_Mask_training.zip
!unzip Dataset_Faces_Mask_validation.zip

Archive:  Dataset_Faces_Mask_training.zip
replace Dataset_Faces_Mask_training/1.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: Archive:  Dataset_Faces_Mask_validation.zip
replace Dataset_Faces_Mask_validation/1000.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [5]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

In [6]:
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs
)

ImportError: ignored

In [None]:
#Hyperparamters

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
IMAGE_HEIGHT = 218
IMAGE_WIDTH = 178
PIN_MEMORY = True
LOAD_MODEL = True
IMG_DIR_TRAIN = "/content/Dataset_Faces_training"
IMG_DIR_VAL = "/content/Dataset_Faces_validation"
MASK_DIR_TRAIN = "/content/Dataset_Faces_Mask_training"
MASK_DIR_VAL = "/content/Dataset_Faces_Mask_validation"

In [None]:
def train(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for i_batch, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.floats().unsqueeze(1).to(devide=DEVICE)

        with torch.cuda.amp.autocast():
            preds = model(data)
            loss = loss_fn(preds, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

In [None]:
train_transform = A.compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2()
    ]
)

validation_transform = A.compose(
    [
        A.resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2()
    ]
)

UNET = myUNET(in_channels=3, out_channels=1).to(DEVICE) #for multiple classes change out channels to number of classes
loss_fn = nn.BCEWithLogitsLoss # for multiple classes use cross entropy
optimizer = optim.Adam(UNET.parameters(), lr=LEARNING_RATE)

train_loader = get_loaders(
    IMG_DIR_TRAIN,
    MASK_DIR_TRAIN,
    IMG_DIR_VAL,
    MASK_DIR_VAL,
    BATCH_SIZE,
    train_transform,
    NUM_WORKERS,
    PIN_MEMORY
)

scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    train(train_loader, UNET, optimizer, loss_fn, scaler)