In [None]:
# Pin the torch version this notebook is run against.
!pip install torch==2.1



In [None]:
!pip install tensorboardX
!pip install gdown

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m842.4 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [None]:
# Kaggle CLI, used below to download the imaterialist-fashion-2019-FGVC6 competition data.
!pip install kaggle



In [None]:
# Colab-only: opens a browser file picker. Upload kaggle.json here; the next
# cell moves it into ~/.kaggle for the kaggle CLI.
# NOTE(review): kaggle.json holds credentials — never commit or print it.
from google.colab import files

uploaded = files.upload()




Saving kaggle.json to kaggle.json


In [None]:
# Put the uploaded kaggle.json in ~/.kaggle and make it readable by the owner only.
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# List available datasets
!kaggle datasets list

# Download a dataset
!kaggle competitions download -c imaterialist-fashion-2019-FGVC6

# Unzip the downloaded dataset
!unzip -q imaterialist-fashion-2019-FGVC6.zip -d ./dataset


ref                                                    title                                            size  lastUpdated          downloadCount  voteCount  usabilityRating  
-----------------------------------------------------  ----------------------------------------------  -----  -------------------  -------------  ---------  ---------------  
thedrcat/daigt-v2-train-dataset                        DAIGT V2 Train Dataset                           29MB  2023-11-16 01:38:36           1101        119  1.0              
muhammadbinimran/housing-price-prediction-data         Housing Price Prediction Data                   763KB  2023-11-21 17:56:32           4046         81  1.0              
maso0dahmed/video-games-data                           Video Games Data                                  5MB  2023-11-25 19:08:46           1090         34  1.0              
carlmcbrideellis/llm-7-prompt-training-dataset         LLM: 7 prompt training dataset                   41MB  2023-11-15 07:3

In [None]:
import os

# The archive has already been extracted into ./dataset; delete it.
# Guarded so re-running this cell after the zip is gone does not raise
# FileNotFoundError.
dataset_zip_path = 'imaterialist-fashion-2019-FGVC6.zip'
if os.path.exists(dataset_zip_path):
    os.remove(dataset_zip_path)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class REBNCONV(nn.Module):
    """3x3 convolution -> BatchNorm2d -> ReLU building block.

    Padding equals the dilation rate, so a 3x3 kernel keeps the spatial size
    unchanged: ``(N, in_ch, H, W) -> (N, out_ch, H, W)``.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        dirate: dilation rate (also used as padding).
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super().__init__()
        rate = dirate
        # Attribute names are state_dict keys (conv_s1.*, bn_s1.*); keep them.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)


## Upsample tensor 'src' (bilinear) to the same spatial size as tensor 'tar'
def _upsample_like(src, tar):
    """Bilinearly resize ``src`` to the spatial size (dims 2+) of ``tar``.

    Args:
        src: 4-D tensor ``(N, C, h, w)`` to resize.
        tar: tensor whose ``shape[2:]`` gives the target ``(H, W)``.

    Returns:
        ``src`` resized to ``(N, C, H, W)``.
    """
    # F.upsample is deprecated and, with align_corners unset, warns on every
    # call; align_corners=False is the value it silently fell back to, so the
    # output is numerically unchanged.
    return F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """RSU-7 block: a 7-level U-shaped encoder/decoder plus a residual add.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, H, W)``. The encoder halves the
    spatial size five times with ``MaxPool2d(2, stride=2, ceil_mode=True)``
    (odd sizes round up). The bottom layer ``rebnconv7`` uses dilation 2
    instead of a further pool. Each decoder level concatenates the upsampled
    deeper features with the same-level encoder features along dim 1. The
    block returns ``hx1d + hxin``.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channel width inside the U-structure.
        out_ch: channels of the output and of the residual branch ``hxin``.
    """

    def __init__(self, in_ch: int = 3, mid_ch: int = 12, out_ch: int = 3):
        super(RSU7, self).__init__()

        # NOTE(review): attribute names are the state_dict keys that pretrained
        # checkpoints are matched against; do not rename them.
        # Input projection in_ch -> out_ch; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels 1-5: conv, then 2x downsample.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Level 6: last encoder level, not pooled.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv at level-6 resolution (wider receptive field
        # without further downsampling).
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each input is [deeper features, skip] concatenated -> mid_ch * 2.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        hx = x
        hxin = self.rebnconvin(hx)

        # Encoder: hx1..hx6 are kept as skip connections for the decoder.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder: concat (deeper, skip) on channels, conv, then upsample to the
        # next skip's spatial size. hx7 and hx6 already share a size.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Disabled cleanup code: a bare string literal, so it never executes.
        """
        del hx1, hx2, hx3, hx4, hx5, hx6, hx7
        del hx6d, hx5d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
        """

        # Residual connection: out_ch channels at the input resolution.
        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """RSU-6 block: a 6-level U-shaped encoder/decoder plus a residual add.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, H, W)``. The encoder halves the
    spatial size four times with ``MaxPool2d(2, stride=2, ceil_mode=True)``.
    The bottom layer ``rebnconv6`` uses dilation 2 instead of a further pool.
    Decoder levels concatenate upsampled deeper features with the same-level
    encoder features along dim 1. The block returns ``hx1d + hxin``.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channel width inside the U-structure.
        out_ch: channels of the output and of the residual branch ``hxin``.
    """

    def __init__(self, in_ch: int = 3, mid_ch: int = 12, out_ch: int = 3):
        super(RSU6, self).__init__()

        # NOTE(review): attribute names are state_dict keys; do not rename.
        # Input projection in_ch -> out_ch; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels 1-4: conv, then 2x downsample.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Level 5: last encoder level, not pooled.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv at level-5 resolution.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each input is [deeper features, skip] concatenated -> mid_ch * 2.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder: hx1..hx5 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # Decoder: concat (deeper, skip) on channels, conv, then upsample to the
        # next skip's spatial size.
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Disabled cleanup code: a bare string literal, so it never executes.
        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup, hx5dup
        """

        # Residual connection: out_ch channels at the input resolution.
        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """RSU-5 block: a 5-level U-shaped encoder/decoder plus a residual add.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, H, W)``. The encoder halves the
    spatial size three times with ``MaxPool2d(2, stride=2, ceil_mode=True)``.
    The bottom layer ``rebnconv5`` uses dilation 2 instead of a further pool.
    Decoder levels concatenate upsampled deeper features with the same-level
    encoder features along dim 1. The block returns ``hx1d + hxin``.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channel width inside the U-structure.
        out_ch: channels of the output and of the residual branch ``hxin``.
    """

    def __init__(self, in_ch: int = 3, mid_ch: int = 12, out_ch: int = 3):
        super(RSU5, self).__init__()

        # NOTE(review): attribute names are state_dict keys; do not rename.
        # Input projection in_ch -> out_ch; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels 1-3: conv, then 2x downsample.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Level 4: last encoder level, not pooled.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv at level-4 resolution.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each input is [deeper features, skip] concatenated -> mid_ch * 2.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder: hx1..hx4 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # Decoder: concat (deeper, skip) on channels, conv, then upsample to the
        # next skip's spatial size.
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Disabled cleanup code: a bare string literal, so it never executes.
        """
        del hx1, hx2, hx3, hx4, hx5
        del hx4d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup
        """

        # Residual connection: out_ch channels at the input resolution.
        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """RSU-4 block: a 4-level U-shaped encoder/decoder plus a residual add.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, H, W)``. The encoder halves the
    spatial size twice with ``MaxPool2d(2, stride=2, ceil_mode=True)``. The
    bottom layer ``rebnconv4`` uses dilation 2 instead of a further pool.
    Decoder levels concatenate upsampled deeper features with the same-level
    encoder features along dim 1. The block returns ``hx1d + hxin``.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channel width inside the U-structure.
        out_ch: channels of the output and of the residual branch ``hxin``.
    """

    def __init__(self, in_ch: int = 3, mid_ch: int = 12, out_ch: int = 3):
        super(RSU4, self).__init__()

        # NOTE(review): attribute names are state_dict keys; do not rename.
        # Input projection in_ch -> out_ch; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder levels 1-2: conv, then 2x downsample.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Level 3: last encoder level, not pooled.
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom: dilated conv at level-3 resolution.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each input is [deeper features, skip] concatenated -> mid_ch * 2.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder: hx1..hx3 are kept as skip connections.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # Decoder: concat (deeper, skip) on channels, conv, then upsample to the
        # next skip's spatial size.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Disabled cleanup code: a bare string literal, so it never executes.
        """
        del hx1, hx2, hx3, hx4
        del hx3d, hx2d
        del hx2dup, hx3dup
        """

        # Residual connection: out_ch channels at the input resolution.
        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """RSU-4F block: a 4-level U-structure built from dilations instead of pooling.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, H, W)``. There is no pooling or
    upsampling. Every layer keeps full resolution (padding equals dilation).
    The encoder uses dilations 1, 2, 4, 8 and the decoder 4, 2, 1. Decoder
    levels concatenate deeper and same-level features along dim 1. The block
    returns ``hx1d + hxin``.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channel width inside the block.
        out_ch: channels of the output and of the residual branch ``hxin``.
    """

    def __init__(self, in_ch: int = 3, mid_ch: int = 12, out_ch: int = 3):
        super(RSU4F, self).__init__()

        # NOTE(review): attribute names are state_dict keys; do not rename.
        # Input projection in_ch -> out_ch; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: growing dilation replaces downsampling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        # Bottom: largest dilation.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder mirrors the encoder dilations; inputs are concatenations -> mid_ch * 2.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # All tensors share one spatial size, so no upsampling is needed.
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        # Disabled cleanup code: a bare string literal, so it never executes.
        """
        del hx1, hx2, hx3, hx4
        del hx3d, hx2d
        """

        # Residual connection.
        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):
    """U^2-Net: six encoder stages and five decoder stages built from RSU blocks.

    Input is ``(N, in_ch, H, W)``. Each ``pool*`` is ``MaxPool2d(2, stride=2,
    ceil_mode=True)``, so encoder stage k (1-based) runs at about
    ``1 / 2**(k-1)`` of the input resolution. Each decoder stage concatenates
    the upsampled deeper output with the same-level encoder output along dim 1.

    ``forward`` returns seven ``(N, out_ch, H, W)`` tensors:
    ``(d0, d1, d2, d3, d4, d5, d6)``. Here ``d1..d6`` are side outputs from
    ``hx1d, hx2d, hx3d, hx4d, hx5d, hx6``, upsampled to ``d1``'s size.
    ``d0`` is a 1x1-conv fusion of all six. No activation is applied; the
    outputs are raw scores.

    Args:
        in_ch: input channels.
        out_ch: channels of every output map.
    """

    def __init__(self, in_ch: int = 3, out_ch: int = 1):
        super(U2NET, self).__init__()

        # NOTE(review): attribute names form the state_dict keys (e.g.
        # "stage1.rebnconvin.conv_s1.weight") used by the checkpoint surgery;
        # do not rename them.
        # Encoder: RSU(in, mid, out) stages with 2x pooling between them.
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # The two deepest stages use the pooling-free RSU4F.
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        # Input channels = deeper output + same-level encoder output.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side heads: 3x3 conv from each decoder stage (and stage6) to out_ch maps.
        # These shapes depend on out_ch, so they differ from a checkpoint
        # trained with a different out_ch.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # 1x1 conv fusing the six concatenated side outputs into d0.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x: torch.Tensor):

        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        # Each stage: concat (upsampled deeper, skip) on channels, then upsample
        # the result to the next skip's spatial size.
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        # All side maps are brought to d1's (full) resolution before fusion.
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # Disabled cleanup code: a bare string literal, so it never executes.
        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d, hx1d
        del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
        """

        # Fused output first, then the six side outputs (all un-activated).
        return d0, d1, d2, d3, d4, d5, d6

In [None]:
import os
import gdown
import torch




os.makedirs("prev_checkpoints", exist_ok=True)

# Pretrained U^2-Net weights. Skip the download when the file already exists
# so Restart & Run All does not re-fetch ~176 MB each time.
trained_net_pth = os.path.join("prev_checkpoints", "u2net.pth")
if not os.path.exists(trained_net_pth):
    gdown.download(
        "https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
        trained_net_pth,
        quiet=False,
    )


def save_checkpoint(model, save_path):
    """Print ``save_path`` and save ``model.state_dict()`` there, creating parent dirs."""
    print(save_path)
    save_dir = os.path.dirname(save_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)


# Randomly initialised network whose output heads have 4 channels.
u_net = U2NET(in_ch=3, out_ch=4)
# u2net_random.pth contains random weights
custom_net_pth = os.path.join("prev_checkpoints", "u2net_random.pth")
save_checkpoint(u_net, custom_net_pth)

# NOTE(review): torch.load unpickles the file; only load checkpoints from
# sources you trust. map_location="cpu" lets GPU-saved weights load anywhere.
net_state_dict = torch.load(trained_net_pth, map_location="cpu")
print("Total number of entries in trained state_dict: {}".format(len(net_state_dict)))

custom_state_dict = torch.load(custom_net_pth, map_location="cpu")
print("Total number of entries in custom state_dict: {}".format(len(custom_state_dict)))

# Copy every trained tensor whose key exists in the custom model with the same
# shape; everything else (e.g. heads sized by out_ch) keeps its random init.
total_count = 0
update_count = 0
skipped_keys = []
for k, v in net_state_dict.items():
    total_count += 1
    if k in custom_state_dict and custom_state_dict[k].shape == v.shape:
        update_count += 1
        custom_state_dict[k] = v
    else:
        skipped_keys.append(k)

print(
    "Out of {} entries in the trained model, {} were copied into the custom network".format(
        total_count, update_count
    )
)
print("Not copied (missing or shape mismatch): {}".format(skipped_keys))
torch.save(
    custom_state_dict, os.path.join("prev_checkpoints", "cloth_segm_unet_surgery.pth")
)
print("cloth_segm_unet_surgery.pth is generated in prev_checkpoints directory!")

Downloading...
From: https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ
To: /content/prev_checkpoints/u2net.pth
100%|██████████| 176M/176M [00:01<00:00, 159MB/s]


prev_checkpoints/u2net_random.pth
Total number of layers in trained model are: 686
Total number of layers in trained model are: 798
Out of 686 layers in custom network, 672 layers weights are recovered from trained model
cloth_segm_unet_surgery.pth is generated in prev_checkpoints directory!


In [None]:
# The two state dicts have different sizes (see the counts printed above).
# Dumping every key (~1.5k names) hides the difference, so show only the keys
# unique to each side.
trained_keys = set(net_state_dict)
custom_keys = set(custom_state_dict)

print("Keys only in trained model ({}):".format(len(trained_keys - custom_keys)))
print(sorted(trained_keys - custom_keys))

print("Keys only in randomly initialized model ({}):".format(len(custom_keys - trained_keys)))
print(sorted(custom_keys - trained_keys))


Keys in trained model:
['stage1.rebnconvin.conv_s1.weight', 'stage1.rebnconvin.conv_s1.bias', 'stage1.rebnconvin.bn_s1.weight', 'stage1.rebnconvin.bn_s1.bias', 'stage1.rebnconvin.bn_s1.running_mean', 'stage1.rebnconvin.bn_s1.running_var', 'stage1.rebnconv1.conv_s1.weight', 'stage1.rebnconv1.conv_s1.bias', 'stage1.rebnconv1.bn_s1.weight', 'stage1.rebnconv1.bn_s1.bias', 'stage1.rebnconv1.bn_s1.running_mean', 'stage1.rebnconv1.bn_s1.running_var', 'stage1.rebnconv2.conv_s1.weight', 'stage1.rebnconv2.conv_s1.bias', 'stage1.rebnconv2.bn_s1.weight', 'stage1.rebnconv2.bn_s1.bias', 'stage1.rebnconv2.bn_s1.running_mean', 'stage1.rebnconv2.bn_s1.running_var', 'stage1.rebnconv3.conv_s1.weight', 'stage1.rebnconv3.conv_s1.bias', 'stage1.rebnconv3.bn_s1.weight', 'stage1.rebnconv3.bn_s1.bias', 'stage1.rebnconv3.bn_s1.running_mean', 'stage1.rebnconv3.bn_s1.running_var', 'stage1.rebnconv4.conv_s1.weight', 'stage1.rebnconv4.conv_s1.bias', 'stage1.rebnconv4.bn_s1.weight', 'stage1.rebnconv4.bn_s1.bias', 's

In [None]:
import os
import copy
import cv2
import numpy as np
from collections import OrderedDict

import torch


def load_checkpoint(model, checkpoint_path):
    """Load a CPU-mapped state dict from ``checkpoint_path`` into ``model``.

    Returns:
        ``model`` after loading, or ``None`` (after printing a notice) when the
        path does not exist.
    """
    if os.path.exists(checkpoint_path):
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        model.load_state_dict(state)
        print("----checkpoints loaded from path: {}----".format(checkpoint_path))
        return model
    print("----No checkpoints at given path----")
    return None


def load_checkpoint_mgpu(model, checkpoint_path):
    """Load a checkpoint whose keys may carry a ``"module."`` wrapper prefix.

    The prefix is stripped only from keys that actually start with it, then the
    state dict (mapped to CPU) is loaded into ``model``.

    Args:
        model: the unwrapped ``nn.Module`` to load into.
        checkpoint_path: path to a file produced by ``torch.save(state_dict)``.

    Returns:
        ``model`` after loading, or ``None`` (after printing a notice) when the
        path does not exist.
    """
    if not os.path.exists(checkpoint_path):
        print("----No checkpoints at given path----")
        return
    model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in model_state_dict.items():
        # Slicing k[7:] unconditionally would mangle keys of checkpoints saved
        # without a wrapper (e.g. "stage1.x" -> "x"); strip only a real prefix.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    print("----checkpoints loaded from path: {}----".format(checkpoint_path))
    return model


def save_checkpoint(model, save_path):
    """Save ``model.state_dict()`` to ``save_path``, creating parent directories.

    Prints ``save_path`` first as a progress log.
    """
    print(save_path)
    save_dir = os.path.dirname(save_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    # exist_ok also avoids a race if the directory appears after the check.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)


def save_checkpoints(opt, itr, net):
    """Save ``net`` to ``<opt.save_dir>/checkpoints/itr_<itr as 8 digits>_u2net.pth``."""
    filename = "itr_{:08d}_u2net.pth".format(itr)
    target = os.path.join(opt.save_dir, "checkpoints", filename)
    save_checkpoint(net, target)

#-----------------
import os
import math
import numpy as np
import random
import pickle
import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def set_seed(seed):
    """Seed torch (CPU and all CUDA devices), numpy and ``random``.

    Also asks cuDNN for deterministic algorithms and disables its auto-tuner.
    ``PYTHONHASHSEED`` is exported for Python processes started afterwards;
    it does not change hashing in the already-running interpreter.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune per input shape and may pick
    # non-deterministic kernels, which defeats `deterministic` above.
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


def synchronize():
    """Barrier across ranks; a no-op unless torch.distributed is initialised with >1 process."""
    active = dist.is_available() and dist.is_initialized()
    if active and dist.get_world_size() != 1:
        dist.barrier()


def cleanup(distributed):
    """Destroy the default process group when ``distributed`` is truthy."""
    if not distributed:
        return
    dist.destroy_process_group()


def get_world_size():
    """Process count of the default group, or 1 when torch.distributed is unavailable or uninitialised."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
#-----------------
import os
import numpy as np
from PIL import Image

import torch
from torch.utils.tensorboard import SummaryWriter

# Helpers for logging images to TensorBoard (via torch.utils.tensorboard, not tensorboardX)


def tensor_for_board(img_tensor):
    """Prepare an image batch for TensorBoard.

    Maps values from [-1, 1] to [0, 1], clamping anything outside. Moves the
    result to CPU and repeats single-channel images to 3 channels. Expects a
    4-D ``(N, C, H, W)`` tensor, since the channel check looks at dim 1. The
    input tensor is not modified.
    """
    # Assign the result: `.cpu()` and `.clamp()` return new tensors rather than
    # modifying in place, so calling them without assignment had no effect.
    tensor = ((img_tensor + 1) * 0.5).cpu().clamp(0, 1)

    if tensor.size(1) == 1:
        tensor = tensor.repeat(1, 3, 1, 1)

    return tensor


def tensor_list_for_board(img_tensors_list):
    """Tile a 2-D list of image batches into one canvas batch for TensorBoard.

    Row ``r`` of the grid holds ``img_tensors_list[r]``. Cells not covered by
    shorter rows stay filled with 0.5. Every tensor is expected to have the
    shape of the first one after ``tensor_for_board``.
    """
    rows = len(img_tensors_list)
    cols = max(len(row) for row in img_tensors_list)

    n, c, h, w = tensor_for_board(img_tensors_list[0][0]).size()
    canvas = torch.full((n, c, rows * h, cols * w), 0.5, dtype=torch.float32)

    for r, row in enumerate(img_tensors_list):
        top = r * h
        for col, img in enumerate(row):
            left = col * w
            canvas[:, :, top:top + h, left:left + w].copy_(tensor_for_board(img))

    return canvas


def board_add_image(board, tag_name, img_tensor, step_count):
    """Log each image of a batch to ``board`` (e.g. a SummaryWriter) as ``<tag_name>/<idx:03d>``."""
    batch = tensor_for_board(img_tensor)
    for idx in range(len(batch)):
        board.add_image('%s/%03d' % (tag_name, idx), batch[idx], step_count)


def board_add_images(board, tag_name, img_tensors_list, step_count):
    """Tile a grid of image batches and log each tiled image as ``<tag_name>/<idx:03d>``."""
    canvas = tensor_list_for_board(img_tensors_list)
    for idx in range(len(canvas)):
        board.add_image('%s/%03d' % (tag_name, idx), canvas[idx], step_count)

In [None]:
import os.path as osp
import os


class parser(object):
    """Hard-coded training options for the cloth-segmentation U^2-Net run.

    Despite the name, no command-line parsing happens. Every option is a
    plain attribute set in ``__init__``.
    """

    def __init__(self):
        # ---- experiment & data (Colab paths) ----
        self.name = "training_cloth_segm_u2net_exp1"  # experiment name
        self.image_folder = "/content/dataset/train/"  # image folder path
        self.df_path = "/content/dataset/train.csv"  # label csv path
        self.distributed = False  # True for multi-GPU training
        self.isTrain = True

        # ---- square input size: 192 * 4 = 768 px ----
        side = 192 * 4
        self.fine_width = side
        self.fine_height = side

        # ---- normalisation ----
        self.mean = 0.5
        self.std = 0.5

        # ---- data loading ----
        self.batchSize = 2  # 12
        self.nThreads = 2  # 3
        self.max_dataset_size = float("inf")
        self.serial_batches = False

        # ---- resume from the checkpoint produced by the surgery cell ----
        self.continue_train = True
        if self.continue_train:
            self.unet_checkpoint = "prev_checkpoints/cloth_segm_unet_surgery.pth"

        # ---- logging / checkpoint frequencies ----
        self.save_freq = 1000
        self.print_freq = 10
        self.image_log_freq = 100

        # ---- optimisation ----
        self.iter = 100000
        self.lr = 0.0002
        self.clip_grad = 5

        # ---- output directories ----
        self.logs_dir = osp.join("logs", self.name)
        self.save_dir = osp.join("results", self.name)

In [None]:
class BaseDataLoader:
    """Minimal base class for data loaders: stores options and exposes hooks to override."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Store the options object on the instance."""
        self.opt = opt

    def load_data(self):
        """Return the loaded data; the base class has none and returns ``None``.

        ``self`` was missing from the signature, so ``instance.load_data()``
        raised TypeError.
        """
        return None

In [None]:
import os
from PIL import Image
import cv2
import numpy as np
import random

import torch
import torch.utils.data as data
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    """Base dataset exposing ``name``/``initialize`` hooks for subclasses."""

    def __init__(self):
        super().__init__()

    def name(self):
        """Human-readable dataset name."""
        return "BaseDataset"

    def initialize(self, opt):
        """Configure from an options object; a no-op in the base class."""
        return None


class Rescale_fixed(object):
    """Resize a PIL image to a fixed size with bicubic resampling.

    Args:
        output_size: passed straight to ``PIL.Image.Image.resize``, i.e. a
            ``(width, height)`` tuple.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, image):
        target_size = self.output_size
        resized = image.resize(target_size, Image.BICUBIC)
        return resized


class Rescale_custom(object):
    """Resize an input/target PIL image pair to a random size, keeping aspect ratio.

    The shorter side is set to a size drawn with ``np.random.randint(min_size,
    min(max_size, h, w))``. When that range is empty, ``min_size`` is used.
    Both images are resized bicubically. The chosen size is stored on
    ``self.output_size``.

    Args:
        min_size (int or float): lower bound for the shorter side.
        max_size: upper bound (exclusive) for the shorter side.
    """

    def __init__(self, min_size, max_size):
        assert isinstance(min_size, (int, float))
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, sample):
        src = sample["input_image"]
        tgt = sample["target_image"]

        assert src.size == tgt.size
        w, h = src.size

        upper = min(self.max_size, h, w)
        if upper > self.min_size:
            self.output_size = np.random.randint(self.min_size, upper)
        else:
            self.output_size = self.min_size

        # Scale so the shorter side equals output_size.
        short = self.output_size
        if h > w:
            size = (int(short), int(short * h / w))
        else:
            size = (int(short * w / h), int(short))

        return {
            "input_image": src.resize(size, Image.BICUBIC),
            "target_image": tgt.resize(size, Image.BICUBIC),
        }


class ToTensor(object):
    """Apply torchvision ``ToTensor`` to both images of an input/target sample dict."""

    def __init__(self):
        self.totensor = transforms.ToTensor()

    def __call__(self, sample):
        keys = ("input_image", "target_image")
        return {key: self.totensor(sample[key]) for key in keys}


class RandomCrop_custom(object):
    """Apply one random crop jointly to the input and target tensors of a sample.

    The two tensors are concatenated along dim 0, cropped once with
    ``transforms.RandomCrop`` and split back. The first 3 channels become
    ``input_image`` and the rest ``target_image``. This presumes a 3-channel,
    channel-first ``input_image``.

    Args:
        output_size (tuple or int): crop size; an int gives a square crop.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        is_square = isinstance(output_size, int)
        if not is_square:
            assert len(output_size) == 2
        self.output_size = (output_size, output_size) if is_square else output_size

        self.randomcrop = transforms.RandomCrop(self.output_size)

    def __call__(self, sample):
        stacked = torch.cat((sample["input_image"], sample["target_image"]))
        cropped = self.randomcrop(stacked)
        return {"input_image": cropped[:3, :], "target_image": cropped[3:, :]}


class Normalize_custom(object):
    """Normalize both tensors of an input/target sample dict with one mean/std.

    Args:
        mean (int, float or 3-sequence): per-channel mean; a scalar is used for
            all three channels.
        std (int, float or 3-sequence): per-channel std; a scalar is used for
            all three channels.

    Raises:
        AssertionError: if ``mean``/``std`` is neither a number nor a length-3 sequence.
    """

    def __init__(self, mean, std):
        # Validate mean and std the same way. Previously ints were rejected for
        # mean despite the docstring, and an int std crashed inside len().
        self.mean = self._per_channel(mean, "mean")
        self.std = self._per_channel(std, "std")

        self.normalize = transforms.Normalize(self.mean, self.std)

    @staticmethod
    def _per_channel(value, label):
        """Expand a scalar to a 3-tuple of floats, or validate a 3-element sequence."""
        if isinstance(value, (int, float)):
            return (float(value),) * 3
        assert isinstance(value, (tuple, list)), "%s must be a number or a 3-sequence" % label
        assert len(value) == 3, "%s must have exactly 3 values" % label
        return tuple(value)

    def __call__(self, sample):
        input_image, target_image = sample["input_image"], sample["target_image"]

        return {
            "input_image": self.normalize(input_image),
            "target_image": self.normalize(target_image),
        }


class Normalize_image(object):
    """Normalize a tensor with a scalar mean/std broadcast over its channels.

    Dispatch is on ``image_tensor.shape[0]``, so an unbatched ``(C, H, W)``
    tensor is expected. Only C in {1, 3, 18} is supported.

    Args:
        mean (float): value subtracted from every channel.
        std (float or int): value every channel is divided by.

    Raises:
        AssertionError: if ``mean`` is not a float.
        ValueError: from ``__call__``, for an unsupported channel count.
    """

    def __init__(self, mean, std):
        assert isinstance(mean, float)
        self.mean = mean
        # Coerce so e.g. std=1 works; previously a non-float std left
        # ``self.std`` unset and the lines below raised AttributeError.
        self.std = float(std)

        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)

        elif channels == 3:
            return self.normalize_3(image_tensor)

        elif channels == 18:
            return self.normalize_18(image_tensor)

        # `assert "message"` always passed (non-empty string), so unsupported
        # inputs silently returned None; fail loudly instead.
        raise ValueError(
            "Please set proper channels! Normalization implemented only for 1, 3 and 18"
        )

###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified from the original: the helpers below only look at entries directly
# inside the given directory; they do not recurse into subdirectories.
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

# Case-sensitive suffixes treated as images (note: only lowercase '.tiff').
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
    '.tiff',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset(dir):
    """List image files directly inside ``dir`` (no recursion), sorted by path.

    Only regular files accepted by ``is_image_file`` are returned, so stray
    files (CSV, hidden files) and subdirectories no longer end up in the
    dataset. Sorting makes the order independent of ``os.listdir``.

    Raises:
        AssertionError: if ``dir`` is not a directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    return sorted(
        os.path.join(dir, name)
        for name in os.listdir(dir)
        if is_image_file(name) and os.path.isfile(os.path.join(dir, name))
    )


def make_dataset_test(dir):
    """Build ``[dir/0.ext, dir/1.ext, ...]`` with one entry per regular file in ``dir``.

    Only the file count is used, not the actual names. ``ext`` is ``.png``
    when the last ``_``-separated part of the directory name is ``label`` or
    ``labelref``, otherwise ``.jpg``.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    suffix = dir.split('/')[-1].split('_')[-1]
    extension = '.png' if suffix in ('label', 'labelref') else '.jpg'
    file_count = sum(1 for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name)))
    return [os.path.join(dir, str(i) + extension) for i in range(file_count)]


def default_loader(path):
    """Open the image at ``path`` and return an RGB copy of it.

    The ``with`` block closes the underlying file. ``convert`` has already
    loaded the pixel data into the returned image, so it stays usable.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


class ImageFolder(data.Dataset):
    """Dataset over the file paths returned by ``make_dataset(root)``.

    Each item is ``loader(path)`` (an RGB PIL image with the default loader),
    passed through ``transform`` when one is given. With
    ``return_paths=True``, items are ``(img, path)``.

    Raises:
        RuntimeError: if ``make_dataset`` returns no paths for ``root``.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        return len(self.imgs)

In [None]:


import os
import cv2
import json
import itertools
import collections
from tqdm import tqdm

import pandas as pd
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms


class AlignedDataset(BaseDataset):
    """Dataset yielding ``(image_tensor, target_tensor)`` for cloth segmentation.

    ``image_tensor`` is the RGB image resized to ``(opt.fine_width,
    opt.fine_height)``, converted with ``ToTensor`` and normalized with
    ``Normalize_image(opt.mean, opt.std)``. ``target_tensor`` is an int64 label
    map of shape ``(fine_height, fine_width)`` with values
    0 = no garment, 1 = upper-body group, 2 = lower-body group,
    3 = whole-body group.
    """

    # Category groups, expressed as the CategoryId values parsed from the
    # ``ClassId`` column (the part before the first "_"). The previous version
    # compared these 0-based groups against ``CategoryId + 1``. As a result,
    # every category was shifted one group over, and the "no object"
    # placeholder label 0 was painted as upper body.
    UPPERBODY_IDS = (0, 1, 2, 3, 4, 5)
    LOWERBODY_IDS = (6, 7, 8)
    WHOLEBODY_IDS = (9, 10, 11, 12)
    # Instances whose resized bounding box is smaller than this in x or y are ignored.
    MIN_BOX_SIZE = 20

    def initialize(self, opt):
        """Read the annotation CSV and build a per-image index in ``self.image_info``.

        Reads ``opt.image_folder``, ``opt.df_path``, ``opt.fine_width``,
        ``opt.fine_height``, ``opt.mean`` and ``opt.std``.
        """
        self.opt = opt
        self.image_dir = opt.image_folder
        self.df_path = opt.df_path
        self.width = opt.fine_width
        self.height = opt.fine_height

        # RGB images: to tensor, then normalize with the configured mean/std.
        self.transform_rgb = transforms.Compose(
            [transforms.ToTensor(), Normalize_image(opt.mean, opt.std)]
        )

        self.df = pd.read_csv(self.df_path)
        self.image_info = collections.defaultdict(dict)
        # Keep only the part of ClassId before the first "_" as the category id.
        self.df["CategoryId"] = self.df.ClassId.apply(lambda x: str(x).split("_")[0])

        # Select columns with a *list*: tuple selection on a GroupBy was
        # deprecated in pandas 1.x and raises ValueError in pandas 2.x.
        grouped = self.df.groupby("ImageId")
        temp_df = (
            grouped[["EncodedPixels", "CategoryId"]]
            .agg(lambda x: list(x))
            .reset_index()
        )
        size_df = grouped[["Height", "Width"]].mean().reset_index()
        temp_df = temp_df.merge(size_df, on="ImageId", how="left")

        for index, row in tqdm(temp_df.iterrows(), total=len(temp_df)):
            image_id = row["ImageId"]
            self.image_info[index] = {
                "image_id": image_id,
                "image_path": os.path.join(self.image_dir, image_id),
                "width": self.width,
                "height": self.height,
                "labels": row["CategoryId"],
                "orig_height": row["Height"],
                "orig_width": row["Width"],
                "annotations": row["EncodedPixels"],
            }

        self.dataset_size = len(self.image_info)

    def __getitem__(self, index):
        """Return ``(image_tensor, target_tensor)`` for the image at ``index``."""
        info = self.image_info[index]

        # Close the file handle deterministically; convert() returns a loaded copy.
        with Image.open(info["image_path"]) as raw:
            img = raw.convert("RGB")
        img = img.resize((self.width, self.height), resample=Image.BICUBIC)
        image_tensor = self.transform_rgb(img)

        # numpy arrays are (rows, cols) = (height, width); PIL sizes are (width, height).
        # Boolean accumulators avoid the uint8 overflow of summing masks.
        upper = np.zeros((self.height, self.width), dtype=bool)
        lower = np.zeros((self.height, self.width), dtype=bool)
        whole = np.zeros((self.height, self.width), dtype=bool)

        orig_shape = (info["orig_height"], info["orig_width"])
        for annotation, label in zip(info["annotations"], info["labels"]):
            category = int(label)
            sub_mask = Image.fromarray(self.rle_decode(annotation, orig_shape))
            sub_mask = np.asarray(
                sub_mask.resize((self.width, self.height), resample=Image.BICUBIC)
            )
            if not self._has_min_extent(sub_mask, self.MIN_BOX_SIZE):
                continue
            if category in self.UPPERBODY_IDS:
                upper |= sub_mask > 0
            elif category in self.LOWERBODY_IDS:
                lower |= sub_mask > 0
            elif category in self.WHOLEBODY_IDS:
                whole |= sub_mask > 0
            # Any other category is left as 0 in the target.

        final_label = (
            upper.astype(np.uint8)
            + 2 * lower.astype(np.uint8)
            + 3 * whole.astype(np.uint8)
        )
        # Sums above 3 mean overlapping groups; such pixels fall back to class 1.
        final_label[final_label > 3] = 1
        target_tensor = torch.as_tensor(final_label, dtype=torch.int64)

        return image_tensor, target_tensor

    @staticmethod
    def _has_min_extent(mask, min_size):
        """True if the nonzero pixels of ``mask`` span at least ``min_size`` in both x and y."""
        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            return False
        return (xs.max() - xs.min()) >= min_size and (ys.max() - ys.min()) >= min_size

    def __len__(self):
        return len(self.image_info)

    def name(self):
        return "AlignedDataset"

    def rle_decode(self, mask_rle, shape):
        """Decode a run-length-encoded mask into a uint8 array.

        Args:
            mask_rle: "start0 length0 start1 length1 ..." with 1-based starts
                into the flattened image.
            shape: (height, width) of the array to return.

        Returns:
            np.ndarray of shape (height, width): 1 inside the mask, 0 elsewhere.
            Runs are laid out column by column: the flat buffer is reshaped to
            (width, height) and transposed.
        """
        height, width = int(shape[0]), int(shape[1])
        tokens = mask_rle.split()
        starts = np.asarray(tokens[0::2], dtype=int) - 1
        lengths = np.asarray(tokens[1::2], dtype=int)
        flat = np.zeros(height * width, dtype=np.uint8)
        for lo, hi in zip(starts, starts + lengths):
            flat[lo:hi] = 1
        return flat.reshape((width, height)).T  # Needed to align to RLE direction

In [None]:
import torch.utils.data


def CreateDataset(opt):
    """Build an ``AlignedDataset``, announce it by name, and initialize it from ``opt``."""
    created = AlignedDataset()
    print("dataset [%s] was created" % (created.name()))
    created.initialize(opt)
    return created


class CustomDatasetDataLoader(BaseDataLoader):
    """Training loader that wraps ``CreateDataset(opt)`` in a ``torch`` ``DataLoader``.

    The sampler comes from ``data_sampler``. It shuffles unless
    ``opt.serial_batches`` is set and is distributed when ``opt.distributed``
    is set.
    """

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        sampler = data_sampler(self.dataset, not opt.serial_batches, opt.distributed)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            sampler=sampler,
            num_workers=int(opt.nThreads),
            pin_memory=True,
        )

    def get_loader(self):
        return self.dataloader

    def __len__(self):
        # Reported size is capped by opt.max_dataset_size (self.opt is presumably
        # set by BaseDataLoader.initialize). The DataLoader itself still covers
        # the full dataset.
        return min(len(self.dataset), self.opt.max_dataset_size)


def data_sampler(dataset, shuffle, distributed):
    """Pick the sampler for ``dataset``.

    Returns a ``DistributedSampler(shuffle=shuffle)`` when ``distributed`` is
    truthy. Otherwise returns a ``RandomSampler`` when ``shuffle`` is truthy,
    else a ``SequentialSampler``.
    """
    samplers = torch.utils.data
    if not distributed:
        if shuffle:
            return samplers.RandomSampler(dataset)
        return samplers.SequentialSampler(dataset)
    return samplers.distributed.DistributedSampler(dataset, shuffle=shuffle)


def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it after each full pass.

    Before every pass, if ``loader.sampler`` has a ``set_epoch`` method (as
    ``torch.utils.data.DistributedSampler`` does), it is called with the pass
    index. Without this, a shuffling ``DistributedSampler`` replays the same
    order on every pass.

    Raises:
        RuntimeError: if a full pass over ``loader`` yields no batches.
            Previously this case spun forever without yielding.
    """
    epoch = 0
    while True:
        sampler = getattr(loader, "sampler", None)
        set_epoch = getattr(sampler, "set_epoch", None)
        if callable(set_epoch):
            set_epoch(epoch)

        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError("loader yielded no batches; cannot sample indefinitely")
        epoch += 1


class CustomTestDataLoader(BaseDataLoader):
    """Evaluation loader that wraps ``CreateDataset(opt)`` in a ``torch`` ``DataLoader``.

    No sampler is passed, so the ``DataLoader`` default applies (sequential
    order, no shuffling).
    """

    def name(self):
        # Previously returned 'CustomDatasetDataLoader' (copy-paste from the
        # training loader), which misidentified this loader in logs.
        return 'CustomTestDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            num_workers=int(opt.nThreads),
            pin_memory=True)

    def get_loader(self):
        return self.dataloader

    def __len__(self):
        # Reported size is capped by opt.max_dataset_size (self.opt is presumably
        # set by BaseDataLoader.initialize).
        return min(len(self.dataset), self.opt.max_dataset_size)

In [None]:
#!cd /content/dataset/

In [None]:
import os
import sys
import time
import yaml
import cv2
import pprint
import traceback
import numpy as np

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchvision import models



def options_printing_saving(opt):
    """Create the output directories, dump ``vars(opt)`` to YAML and print every option.

    Creates ``opt.logs_dir``, ``opt.save_dir``, ``<save_dir>/images`` and
    ``<save_dir>/checkpoints`` (existing directories are left alone), then
    writes the options to ``<save_dir>/training_options.yml``.
    """
    for folder in (
        opt.logs_dir,
        opt.save_dir,
        os.path.join(opt.save_dir, "images"),
        os.path.join(opt.save_dir, "checkpoints"),
    ):
        os.makedirs(folder, exist_ok=True)

    # Saving options in yml file
    options = vars(opt)
    yml_path = os.path.join(opt.save_dir, "training_options.yml")
    with open(yml_path, "w") as handle:
        yaml.dump(options, handle)

    for key in options:
        print(key, options[key])


def _deep_supervision_loss(outputs, label, criterion):
    """Sum ``criterion`` over all network outputs, weighting the first one by 1.5.

    Returns ``(total_loss, loss_of_first_output)``.
    """
    losses = [criterion(d, label) for d in outputs]
    total = losses[0] * 1.5
    for extra in losses[1:]:
        total = total + extra
    return total, losses[0]


def training_loop(opt):
    """Train ``U2NET(in_ch=3, out_ch=4)`` with Adam and weighted cross-entropy.

    Reads from ``opt``:
    - distributed, continue_train, unet_checkpoint
    - save_dir, logs_dir
    - lr, iter, clip_grad
    - print_freq, image_log_freq, save_freq
    - the dataset/loader options consumed by ``CustomDatasetDataLoader``

    Only local rank 0 writes ``networks.txt``, logs to TensorBoard, prints
    progress and saves checkpoints.
    """
    if opt.distributed:
        # LOCAL_RANK is unique only within an individual node.
        local_rank = int(os.environ.get("LOCAL_RANK"))
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda:0")
        local_rank = 0
    is_main = local_rank == 0

    u_net = U2NET(in_ch=3, out_ch=4)
    if opt.continue_train:
        u_net = load_checkpoint(u_net, opt.unet_checkpoint)
    u_net = u_net.to(device)
    u_net.train()

    if is_main:
        with open(os.path.join(opt.save_dir, "networks.txt"), "w") as outfile:
            print("<----U-2-Net---->", file=outfile)
            print(u_net, file=outfile)

    if opt.distributed:
        u_net = nn.parallel.DistributedDataParallel(
            u_net,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )
        print("Going super fast with DistributedDataParallel")

    # Use the configured learning rate. It used to be hard-coded to 0.001,
    # silently ignoring opt.lr even though opt.lr is saved with the options.
    optimizer = optim.Adam(
        u_net.parameters(),
        lr=getattr(opt, "lr", 0.001),
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
    )

    custom_dataloader = CustomDatasetDataLoader()
    custom_dataloader.initialize(opt)
    loader = custom_dataloader.get_loader()

    writer = None
    if is_main:
        dataset_size = len(custom_dataloader)
        print("Total number of images avaliable for training: %d" % dataset_size)
        writer = SummaryWriter(opt.logs_dir)
        print("Entering training loop!")

    # Class 0 weighted 1.0, classes 1-3 weighted 1.5.
    class_weights = torch.tensor([1, 1.5, 1.5, 1.5], dtype=torch.float32, device=device)
    loss_CE = nn.CrossEntropyLoss(weight=class_weights).to(device)

    get_data = sample_data(loader)
    start_time = time.time()
    try:
        for itr in range(opt.iter):
            image, label = next(get_data)
            image = image.to(device)
            label = label.to(device=device, dtype=torch.long)

            outputs = u_net(image)
            d0 = outputs[0]
            total_loss, loss0 = _deep_supervision_loss(outputs, label, loss_CE)
            del outputs

            optimizer.zero_grad(set_to_none=True)
            total_loss.backward()
            if opt.clip_grad != 0:
                nn.utils.clip_grad_norm_(u_net.parameters(), opt.clip_grad)
            optimizer.step()

            if is_main:
                # Plain floats for printing/logging (no autograd-tracked tensors).
                total_value = total_loss.item()
                loss0_value = loss0.item()

                if itr % opt.print_freq == 0:
                    pprint.pprint(
                        "[step-{:08d}] [time-{:.3f}] [total_loss-{:.6f}]  [loss0-{:.6f}]".format(
                            itr, time.time() - start_time, total_value, loss0_value
                        )
                    )

                if itr % opt.image_log_freq == 0:
                    # argmax of the logits equals argmax of their log-softmax.
                    with torch.no_grad():
                        pred = torch.argmax(d0, dim=1, keepdim=True)
                    visuals = [[image, torch.unsqueeze(label, dim=1) * 85, pred * 85]]
                    board_add_images(writer, "grid", visuals, itr)

                writer.add_scalar("total_loss", total_value, itr)
                writer.add_scalar("loss0", loss0_value, itr)

                if itr % opt.save_freq == 0:
                    save_checkpoints(opt, itr, u_net)

        print("Training done!")
        if is_main:
            # Final index = iterations run (last itr + 1). Using opt.iter also
            # works when opt.iter == 0, where `itr` would be unbound.
            save_checkpoints(opt, opt.iter, u_net)
    finally:
        if writer is not None:
            writer.close()


if __name__ == "__main__":

    opt = parser()

    # Only the single process, or local rank 0 when distributed, writes the option dump.
    is_main_process = (not opt.distributed) or int(os.environ.get("LOCAL_RANK")) == 0
    if is_main_process:
        options_printing_saving(opt)

    try:
        if opt.distributed:
            print("Initialize Process Group...")
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            synchronize()

        set_seed(1000)
        training_loop(opt)
        cleanup(opt.distributed)
        print("Exiting..............")

    except (KeyboardInterrupt, Exception) as exc:
        # Ctrl-C exits quietly; any other error prints its traceback first.
        if not isinstance(exc, KeyboardInterrupt):
            traceback.print_exc(file=sys.stdout)
        cleanup(opt.distributed)

name training_cloth_segm_u2net_exp1
image_folder /content/dataset/train/
df_path /content/dataset/train.csv
distributed False
isTrain True
fine_width 768
fine_height 768
mean 0.5
std 0.5
batchSize 2
nThreads 2
max_dataset_size inf
serial_batches False
continue_train True
unet_checkpoint prev_checkpoints/cloth_segm_unet_surgery.pth
save_freq 1000
print_freq 10
image_log_freq 100
iter 100000
lr 0.0002
clip_grad 5
logs_dir logs/training_cloth_segm_u2net_exp1
save_dir results/training_cloth_segm_u2net_exp1
----checkpoints loaded from path: prev_checkpoints/cloth_segm_unet_surgery.pth----
dataset [AlignedDataset] was created


100%|██████████| 45195/45195 [00:04<00:00, 10809.42it/s]

Total number of images avaliable for training: 45195
Entering training loop!





'[step-00000000] [time-70.798] [total_loss-10.427507]  [loss0-1.401052]'
results/training_cloth_segm_u2net_exp1/checkpoints/itr_00000000_u2net.pth
'[step-00000010] [time-95.518] [total_loss-10.259898]  [loss0-1.646647]'
'[step-00000020] [time-122.206] [total_loss-4.788110]  [loss0-0.660185]'
'[step-00000030] [time-144.509] [total_loss-5.270166]  [loss0-0.667101]'
'[step-00000040] [time-172.507] [total_loss-3.836240]  [loss0-0.484880]'
'[step-00000050] [time-197.749] [total_loss-3.102251]  [loss0-0.481346]'
'[step-00000060] [time-227.452] [total_loss-2.119591]  [loss0-0.292702]'
'[step-00000070] [time-245.729] [total_loss-9.094823]  [loss0-1.066927]'
'[step-00000080] [time-260.585] [total_loss-4.226876]  [loss0-0.604385]'
'[step-00000090] [time-282.482] [total_loss-2.408255]  [loss0-0.330811]'
'[step-00000100] [time-302.038] [total_loss-1.885244]  [loss0-0.267589]'
'[step-00000110] [time-320.561] [total_loss-1.577716]  [loss0-0.231840]'
'[step-00000120] [time-341.656] [total_loss-3.8194