<a href="https://colab.research.google.com/github/Locke-bot/FaceReg/blob/master/Untitled6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gc; gc.collect()
import os, tqdm, datetime
from skimage import io, transform
import torch
from torch import Tensor
from PIL import Image
import torchvision
from torch.profiler import profile, record_function, ProfilerActivity
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as Func
# Pick the compute device once: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('running on', device)

running on cuda


In [None]:
# Colab-only step: mount Google Drive at the relative mount point 'content'.
# force_remount=True is passed so the call remounts even when a mount already exists.
from google.colab import drive
drive.mount('content', force_remount=True)

Mounted at content


In [None]:
# Switch into the Carvana folder on the mounted Drive so that later cells can
# refer to 'train', 'train_masks', 'val', 'val_masks' and the checkpoint file
# by bare relative names.
# NOTE(review): the path is relative to the current working directory, so
# re-running this cell after the chdir has already happened looks like it
# would fail — confirm before relying on Restart & Run All.
os.chdir('content/MyDrive/Carvana')
!ls
# Absolute path of the dataset folder we just entered.
data_dir = os.getcwd()

colab_carvana_model.pth  train		  unet_attached      val_masks
test			 train_masks	  unet_attached.png
test_masks		 train_masks.csv  val


In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second stage.
        mid_channels: channels between the two stages; any falsy value
            (``None`` or ``0``) means "use ``out_channels``".

    The stages live in ``self.double_conv`` (an ``nn.Sequential``) in the order
    conv, bn, relu, conv, bn, relu. Both convolutions use ``padding=1``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        # Build the two stages in order so the Sequential indices stay 0..5.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        return self.double_conv(x)

In [None]:
class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a two-element Sequential: index 0 is the pool, index 1 the convs.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and then apply the double convolution."""
        return self.maxpool_conv(x)

In [None]:
class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, double conv.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor.
        out_channels: channels produced by the trailing ``DoubleConv``.
        bilinear: when true, upsample with ``nn.Upsample`` (bilinear,
            ``align_corners=True``) and let the ``DoubleConv`` narrow the
            channels through ``in_channels // 2``; otherwise upsample with a
            stride-2 ``nn.ConvTranspose2d`` that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # No learnable upsampling here, so the convolutions reduce the channel count.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse the two.

        Dimensions 2 and 3 of both tensors are used as the spatial axes.
        """
        upsampled = self.up(x1)
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        # F.pad takes (left, right, top, bottom); any odd remainder goes to right/bottom.
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # The skip connection goes first along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))

In [None]:
class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` features to ``out_channels`` maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``; no activation is applied here."""
        projected = self.conv(x)
        return projected

In [None]:
class UNet(nn.Module):
    """U-Net built from ``DoubleConv``, ``Down``, ``Up`` and ``OutConv`` blocks.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the returned logit map.
        bilinear: forwarded to every ``Up`` block; when true the bottleneck and
            the first three decoder widths are halved (``factor = 2``).

    ``forward`` returns raw logits; no sigmoid/softmax is applied.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear
        factor = 2 if bilinear else 1

        # Encoder (registration order is kept: inc, down1..down4).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder followed by the logit head.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return the logits."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Each decoder step consumes the matching encoder output, deepest first.
        decoded = self.up1(bottom, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)
        return self.outc(decoded)

In [None]:
# Profile one forward pass of a transposed-convolution U-Net on a random
# 32 x 3 x 160 x 240 batch. Only CPU activity is recorded, with memory and
# input-shape tracking, and the ten rows with the largest self CPU memory are printed.
m = UNet(3, 1, bilinear=False).to(device)
t = torch.rand((32, 3, 160, 240)).to(device)
with profile(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
    record_shapes=True,
) as prof:
  m(t)
print(
    prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)
)

In [None]:

class Carro(Dataset):
    """Image/mask pair dataset backed by two directories.

    File names in ``root`` and ``mask_root`` are listed, sorted, and paired by
    position; only the two counts are checked against each other.

    Args:
        root: directory holding the input images.
        mask_root: directory holding the masks.
        transforms: optional callable applied to the image and, separately, to
            the mask.

    When ``root`` is exactly the string ``'train'`` only every 10th file of each
    listing is kept.
    NOTE(review): because image and mask go through the same callable in two
    separate calls, a randomized transform would presumably desynchronize them —
    confirm before adding augmentation.
    """

    def __init__(self, root, mask_root, transforms=None):
        self.root = root
        self.mask_root = mask_root

        image_names = sorted(os.listdir(root))
        mask_names = sorted(os.listdir(mask_root))
        if self.root == 'train':
            # Subsample the training split: keep one file in ten.
            image_names = image_names[::10]
            mask_names = mask_names[::10]
        self.root_list = image_names
        self.mask_root_list = mask_names

        self.transform = transforms
        assert len(self.root_list) == len(self.mask_root_list)

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.root_list)

    def __getitem__(self, idx):
        """Read pair ``idx`` from disk and return ``(img, mask)``, transformed if configured."""
        img = io.imread(os.path.join(self.root, self.root_list[idx]))
        mask = io.imread(os.path.join(self.mask_root, self.mask_root_list[idx]))
        if not self.transform:
            return (img, mask)
        img = self.transform(img)
        mask = self.transform(mask)
        return (img, mask)

In [None]:

# Shared preprocessing pipeline: convert to a tensor first, then resize to 160 x 240.
t = transforms.Compose([transforms.ToTensor(), transforms.Resize((160, 240))])

In [None]:
# Datasets for the two splits; both use the shared transform `t`.
tds = Carro(root='train', mask_root='train_masks', transforms=t)
vds = Carro(root='val', mask_root='val_masks', transforms=t)
# Batches of 4; only the training loader is shuffled.
tdl = DataLoader(dataset=tds, batch_size=4, shuffle=True)
vdl = DataLoader(dataset=vds, batch_size=4, shuffle=False)

In [None]:
# Number of image/mask pairs in the validation dataset.
len(vds)

In [None]:

def validate(model, test_loader, criterion=None, device=None):
    """Evaluate a binary-segmentation model on ``test_loader``.

    The model is switched to eval mode and left there; gradients are disabled
    for the whole pass. Logits are passed through a sigmoid and thresholded at
    0.5 to obtain the predicted mask.

    Args:
        model: module mapping an image batch to logits of the same shape as the
            mask batch.
        test_loader: iterable of ``(imgs, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``. When omitted,
            the notebook-level ``loss_fn`` is used.
        device: device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has none).

    Returns:
        ``(accuracy, mean_loss)`` as Python floats: the pixel accuracy over all
        batches and the loss averaged over batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if criterion is None:
        criterion = loss_fn  # notebook-level default defined in the training cell
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    dice_sum = 0.0
    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            # Metrics compare against a hard mask: the notebook's transform resizes
            # masks as tensors, so edge pixels may hold fractional values that would
            # never compare equal to a 0/1 prediction.
            target = (labels > 0.5).float()
            correct += (predicted == target).sum().item()
            total += predicted.numel()
            dice_sum += (
                (2 * (predicted * target).sum())
                / ((predicted + target).sum() + 1e-8)
            ).item()
            # The loss still sees the unmodified labels.
            loss_sum += criterion(outputs, labels).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError('validate() received an empty loader')

    print(f'validation: {correct}/{total}, dice_score: {dice_sum / n_batches}')
    return correct / total, loss_sum / n_batches

In [None]:

def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Dice coefficient between ``input`` and ``target``.

    Args:
        input: predicted tensor; must have the same size as ``target``.
        target: reference tensor.
        reduce_batch_first: when true, flatten everything (batch included) and
            return one coefficient; when false and the tensors have more than
            two dimensions, average the coefficient over the first dimension.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        A scalar tensor ``(2 * <input, target> + epsilon) / (sum(input) + sum(target) + epsilon)``,
        or the mean of that quantity over the first dimension.

    Raises:
        ValueError: if ``reduce_batch_first`` is set on a 2-D tensor.
    """
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            # Both tensors are all zero: make the ratio come out as exactly 1.
            sets_sum = 2 * inter
        return (2 * inter + epsilon) / (sets_sum + epsilon)

    # Compute the metric per element of the first dimension and average.
    # epsilon is forwarded explicitly so a caller-supplied value is honoured
    # on this path as well (it used to fall back to the default).
    n_items = input.shape[0]
    dice = 0
    for i in range(n_items):
        dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)
    return dice / n_items

def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Mean of ``dice_coeff`` taken separately over each index of dimension 1.

    ``input`` and ``target`` must have identical sizes; ``reduce_batch_first``
    and ``epsilon`` are forwarded unchanged to ``dice_coeff``.
    """
    assert input.size() == target.size()
    n_channels = input.shape[1]
    per_channel = (
        dice_coeff(input[:, c, ...], target[:, c, ...], reduce_batch_first, epsilon)
        for c in range(n_channels)
    )
    return sum(per_channel) / n_channels

def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss, i.e. ``1 - Dice coefficient`` with the batch reduced first.

    ``multiclass`` selects ``multiclass_dice_coeff`` over ``dice_coeff``.
    ``input`` and ``target`` must have identical sizes.
    """
    assert input.size() == target.size()
    if multiclass:
        score = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        score = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - score

In [None]:
# Short training run with RMSprop and a combined (cross-entropy + Dice) loss.
#
# Changes from the previous version of this cell:
#   * removed a stray character after `UNet(3, 1)` that made the cell a syntax error;
#   * the network is moved to `device` (the batches already were);
#   * a single-channel output now uses BCE-with-logits plus a sigmoid Dice term,
#     because the softmax / one-hot path cannot work with one class;
#   * the plateau scheduler is actually stepped, on the epoch's mean Dice;
#   * the end-of-epoch print shows the mean Dice instead of the dice_loss function.
net = UNet(3, 1)
net.to(device)
amp = False
epochs = 5
train_loader = tdl
learning_rate = 1e-3
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
# One output channel is a binary problem; more channels use class-index targets.
criterion = nn.BCEWithLogitsLoss() if net.n_classes == 1 else nn.CrossEntropyLoss()
global_step = 0

# 5. Begin training
n_train = len(tds)
for epoch in range(epochs):
    count = 0        # batches seen this epoch
    dice_score = 0   # running sum of per-batch soft Dice (1 - dice loss)
    net.train()
    epoch_loss = 0
    with tqdm.tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for images, true_masks in train_loader:

            assert images.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            images = images.to(device=device, dtype=torch.float32)

            with torch.cuda.amp.autocast(enabled=amp):
                masks_pred = net(images)
                if net.n_classes == 1:
                    # Float targets with the same shape as the logits.
                    true_masks = true_masks.to(device=device, dtype=torch.float32)
                    batch_dice_loss = dice_loss(torch.sigmoid(masks_pred).float(),
                                                true_masks,
                                                multiclass=False)
                else:
                    # NOTE(review): this branch assumes integer class-index masks of
                    # shape (N, H, W) — confirm the dataset provides that before
                    # using n_classes > 1.
                    true_masks = true_masks.to(device=device, dtype=torch.long)
                    batch_dice_loss = dice_loss(F.softmax(masks_pred, dim=1).float(),
                                                F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                                multiclass=True)
                loss = criterion(masks_pred, true_masks) + batch_dice_loss

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            dice_score += 1 - batch_dice_loss.item()
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            count += 1

    # No validation pass runs in this cell, so the scheduler monitors the
    # training Dice; max(count, 1) keeps an empty loader from dividing by zero.
    mean_dice = dice_score / max(count, 1)
    scheduler.step(mean_dice)
    print(mean_dice, epoch_loss / max(count, 1))

In [None]:
# Main training run: transposed-convolution U-Net, BCE-with-logits loss, Adam.
# After every epoch the model is validated, and the weights with the best
# validation accuracy so far are saved to 'colab_carvana_model.pth'.
model = UNet(3, 1, bilinear=False)
model.to(device)
loss_fn = nn.BCEWithLogitsLoss()
print('here')
lr = 3*10**-4
# `lr` used to be defined but never handed to the optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr)
train_loader = tdl
val_loader = vdl
epochs = 500
best_model = None
best_acc = 0.
for epoch in range(epochs):
    # validate() leaves the model in eval mode, so re-enable training every epoch.
    model.train()
    loss_train = 0
    dice_score = 0
    total = 0
    correct = 0
    # Wrap the loader in a fresh progress bar per epoch without rebinding
    # `train_loader`; rebinding nested one tqdm inside another each epoch.
    for imgs, label in tqdm.tqdm(train_loader):
        imgs = imgs.to(device)
        # The masks are used exactly as the dataset's transform returns them.
        # The former extra `label /= 255` shrank the targets towards zero: in the
        # recorded run the training Dice stayed at 0.0 and accuracy froze at 0.7889.
        label = label.to(device)
        output = model(imgs)

        # Running metrics only; kept out of the autograd graph.
        with torch.no_grad():
            predicted = (torch.sigmoid(output) > 0.5).float()
            # Compare against a hard mask: resizing the mask as a tensor can leave
            # fractional edge values that never equal a 0/1 prediction.
            target = (label > 0.5).float()
            correct += (predicted == target).sum().item()
            total += predicted.numel()
            dice_score += (
                (2 * (predicted * target).sum())
                / ((predicted + target).sum() + 1e-8)
            ).item()

        loss = loss_fn(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.item()
    print(datetime.datetime.now(), epoch,
        loss_train / len(train_loader),
        (correct/total)
    )
    print(f"dice_score_training: {dice_score/len(train_loader)}")
    val_acc, val_loss = validate(model, val_loader)
    if val_acc > best_acc:
        # state_dict() returns references to the live parameters; clone them so
        # `best_model` stays a snapshot instead of tracking later updates.
        best_model = {name: value.detach().clone() for name, value in model.state_dict().items()}
        torch.save(best_model, 'colab_carvana_model.pth')
        best_acc = val_acc
    print(f"val_acc: {val_acc} and val_loss: {val_loss} for {epoch} epoch; best_acc: {best_acc}")

here


100%|██████████| 103/103 [03:52<00:00,  2.25s/it]


2021-12-08 16:25:30.754342 0 0.1595181385145604 tensor(0.7872, device='cuda:0')
dice_score_training: 0.0006792516214773059
validation: 14563933/19200000, dice_score: 0.0008477981318719685
val_acc: 0.7585381865501404 and val_loss: 0.6503865199089051 for 0 epoch; best_acc: 0.7585381865501404


100%|██████████| 103/103 [01:37<00:00,  1.05it/s]


2021-12-08 16:31:40.426153 1 0.032587920053346646 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15928500/19200000, dice_score: 0.0
val_acc: 0.8296093940734863 and val_loss: 0.8933776025772094 for 1 epoch; best_acc: 0.8296093940734863


100%|██████████| 103/103 [01:37<00:00,  1.05it/s]


2021-12-08 16:34:34.018549 2 0.015666116634835896 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 16692500/19200000, dice_score: 0.0
val_acc: 0.8694010376930237 and val_loss: 0.9916566491127015 for 2 epoch; best_acc: 0.8694010376930237


100%|██████████| 103/103 [01:37<00:00,  1.05it/s]


2021-12-08 16:37:27.246727 3 0.010698915992691679 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 16015500/19200000, dice_score: 0.0
val_acc: 0.8341405987739563 and val_loss: 1.0719099535942078 for 3 epoch; best_acc: 0.8694010376930237


100%|██████████| 103/103 [01:36<00:00,  1.06it/s]


2021-12-08 16:40:19.139000 4 0.0086209952912646 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15451000/19200000, dice_score: 0.0
val_acc: 0.8047395944595337 and val_loss: 1.1270350580215454 for 4 epoch; best_acc: 0.8694010376930237


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 16:43:10.774546 5 0.007601583313233066 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 13689000/19200000, dice_score: 0.0
val_acc: 0.7129687666893005 and val_loss: 1.1328835821151733 for 5 epoch; best_acc: 0.8694010376930237


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 16:46:03.032055 6 0.0069955273133530775 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15448500/19200000, dice_score: 0.0
val_acc: 0.8046093583106995 and val_loss: 1.1464912815093995 for 6 epoch; best_acc: 0.8694010376930237


100%|██████████| 103/103 [01:38<00:00,  1.04it/s]


2021-12-08 16:48:57.050443 7 0.006622842671259225 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 14349000/19200000, dice_score: 0.0
val_acc: 0.7473437190055847 and val_loss: 1.1501424932479858 for 7 epoch; best_acc: 0.8694010376930237


100%|██████████| 103/103 [01:38<00:00,  1.05it/s]


2021-12-08 16:51:50.032754 8 0.006348026699377495 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 16768500/19200000, dice_score: 0.0
val_acc: 0.8733593821525574 and val_loss: 1.1749374361038207 for 8 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 16:54:43.272482 9 0.006188447953391712 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 14450000/19200000, dice_score: 0.0
val_acc: 0.7526041865348816 and val_loss: 1.1980934205055236 for 9 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:36<00:00,  1.06it/s]


2021-12-08 16:57:33.903762 10 0.00608313588190426 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 14921500/19200000, dice_score: 0.0
val_acc: 0.7771614789962769 and val_loss: 1.163039876461029 for 10 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 17:00:26.196251 11 0.005953864637509133 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15275500/19200000, dice_score: 0.0
val_acc: 0.7955989837646484 and val_loss: 1.189349681377411 for 11 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:38<00:00,  1.05it/s]


2021-12-08 17:03:18.602619 12 0.005866850631365788 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15427000/19200000, dice_score: 0.0
val_acc: 0.8034895658493042 and val_loss: 1.1814170560836792 for 12 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:35<00:00,  1.08it/s]


2021-12-08 17:06:07.850221 13 0.005809774907143081 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 14349000/19200000, dice_score: 0.0
val_acc: 0.7473437190055847 and val_loss: 1.1769671301841735 for 13 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:35<00:00,  1.08it/s]


2021-12-08 17:08:56.670685 14 0.005749278290432344 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15391000/19200000, dice_score: 0.0
val_acc: 0.8016145825386047 and val_loss: 1.192813573360443 for 14 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:36<00:00,  1.06it/s]


2021-12-08 17:11:46.617956 15 0.005705735659751209 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15592000/19200000, dice_score: 0.0
val_acc: 0.8120833039283752 and val_loss: 1.1851680855751037 for 15 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 17:14:39.852500 16 0.005671640084966004 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15591000/19200000, dice_score: 0.0
val_acc: 0.8120312690734863 and val_loss: 1.1903318510055543 for 16 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:36<00:00,  1.07it/s]


2021-12-08 17:17:30.352636 17 0.005642141066310765 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15498500/19200000, dice_score: 0.0
val_acc: 0.807213544845581 and val_loss: 1.1726024417877197 for 17 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:36<00:00,  1.06it/s]


2021-12-08 17:20:21.101598 18 0.0056132681043909016 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15912500/19200000, dice_score: 0.0
val_acc: 0.8287760615348816 and val_loss: 1.1832067446708678 for 18 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 17:23:12.082610 19 0.005599157785255353 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15547500/19200000, dice_score: 0.0
val_acc: 0.809765636920929 and val_loss: 1.1716214394569398 for 19 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:35<00:00,  1.07it/s]


2021-12-08 17:26:02.208964 20 0.005577946615233583 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15547500/19200000, dice_score: 0.0
val_acc: 0.809765636920929 and val_loss: 1.1788064270019531 for 20 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 17:28:53.099851 21 0.005564295912830575 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 15270500/19200000, dice_score: 0.0
val_acc: 0.79533851146698 and val_loss: 1.1811811456680297 for 21 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:37<00:00,  1.06it/s]


2021-12-08 17:31:45.463678 22 0.005554960990885219 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0
validation: 14663500/19200000, dice_score: 0.0
val_acc: 0.7637239694595337 and val_loss: 1.1755941338539124 for 22 epoch; best_acc: 0.8733593821525574


100%|██████████| 103/103 [01:35<00:00,  1.08it/s]


2021-12-08 17:34:36.209789 23 0.005571297114461805 tensor(0.7889, device='cuda:0')
dice_score_training: 0.0


KeyboardInterrupt: ignored