In [None]:
import albumentations as A
from torch.utils.data import Dataset
def get_train_augs():
    """Return the training-time albumentations pipeline.

    Resizes to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, then applies a horizontal
    flip and a vertical flip, each independently with probability 0.5.
    """
    transforms = [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]
    return A.Compose(transforms)

def get_valid_augs():
    """Return the validation-time pipeline: resize only, no random flips."""
    resize = A.Resize(IMAGE_SIZE, IMAGE_SIZE)
    return A.Compose([resize])
class SegmentationDataset(Dataset):
    """Map-style dataset yielding ``(image, mask)`` float tensor pairs.

    Each row of ``df`` provides an ``images`` and a ``masks`` attribute
    holding file paths. The image is read as RGB and the mask as
    single-channel grayscale. Both are optionally passed through
    ``augmentations`` and then returned channel-first. Image values are
    scaled by 1/255. Mask values are scaled by 1/255 and rounded, giving
    0/1 for 8-bit masks.
    """

    def __init__(self, df, augmentations=None):
        """
        Args:
            df: table indexed positionally via ``df.iloc`` (presumably a
                pandas DataFrame) with ``images`` and ``masks`` columns.
            augmentations: callable invoked as
                ``augmentations(image=..., mask=...)`` that returns a mapping
                with ``"image"`` and ``"mask"`` keys, or ``None`` to skip
                augmentation.
        """
        self.df = df
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of rows (samples) in ``df``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, optionally augment, and tensorize the sample at ``idx``.

        Returns:
            ``(image, mask)`` float32 tensors in channel-first layout:
            ``image`` is ``(3, H, W)`` and ``mask`` is ``(1, H, W)``. This
            assumes ``augmentations`` keeps the channel dimension.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or mask file.
        """
        row = self.df.iloc[idx]
        image_path = row.images
        mask_path = row.masks

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        # Fail here with the offending path instead of inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # (h, w, 3)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # (h, w)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = np.expand_dims(mask, axis=-1)  # (h, w, 1)

        if self.augmentations is not None:
            data = self.augmentations(image=image, mask=mask)
            image = data["image"]
            mask = data["mask"]

        # (h, w, c) -> (c, h, w): PyTorch layers expect channel-first input.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        mask = np.transpose(mask, (2, 0, 1)).astype(np.float32)

        image = torch.Tensor(image) / 255.0
        # Rounding snaps any interpolated mask values back to 0/1.
        mask = torch.round(torch.Tensor(mask) / 255.0)

        return image, mask
# Build the train/validation datasets from the externally defined train_df /
# valid_df tables. Each dataset gets its own augmentation pipeline; only the
# training pipeline applies random flips.
trainset = SegmentationDataset(train_df,get_train_augs())
validset = SegmentationDataset(valid_df,get_valid_augs())

# Index of the training sample to preview as a sanity check.
idx=3

# Returns channel-first (image, mask) tensors. Because the training augs flip
# at random, repeated runs may show a flipped version of the same sample.
image,mask = trainset[idx]
#helper.show_image(image,mask)
# NOTE(review): `helper` is defined outside this view; confirm imshow accepts
# a channel-first (C, H, W) tensor.
helper.imshow(image)

from torch.utils.data import DataLoader
# Shuffle only the training data; validation order does not affect the loss.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
validloader = DataLoader(validset, batch_size=BATCH_SIZE)
print(f"total no. of batches in trainloader : {len(trainloader)}")
print(f"total no. of batches in validloader : {len(validloader)}")

# Peek at one batch to sanity-check tensor shapes. next(iter(...)) fails
# loudly (StopIteration) on an empty loader. The previous `for ...: break`
# would instead leave any earlier `image`/`mask` bound and print their shapes.
image, mask = next(iter(trainloader))

print(f"One batch image shape: {image.shape}")
print(f"One batch mask shape: {mask.shape}")

from torch import nn
import segmentation_models_pytorch as smp 
from segmentation_models_pytorch.losses import DiceLoss

from torch import nn
import segmentation_models_pytorch as smp 
from segmentation_models_pytorch.losses import DiceLoss

class SegmentationModel(nn.Module):
    """Binary segmentation U-Net (``smp.Unet``) with an optional training loss.

    The network has 3 input channels and 1 output channel and returns raw
    logits (``activation=None``). When a target mask is supplied,
    ``forward`` also returns Dice loss + BCE-with-logits loss.
    """

    def __init__(self):
        super().__init__()

        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, images, mask=None):
        """Run the U-Net and, if ``mask`` is given, also compute the loss.

        Args:
            images: input batch for ``smp.Unet``. Presumably
                ``(B, 3, H, W)`` to match ``in_channels=3``; TODO confirm.
            mask: optional binary target with the same shape as the logits.

        Returns:
            ``logits`` when ``mask`` is None. Otherwise ``(logits, loss)``,
            where ``loss`` is the Dice loss plus the BCE-with-logits loss.
        """
        logits = self.arc(images)

        # Identity check: `tensor != None` relies on comparison fallback
        # behavior rather than explicitly testing for "no mask given".
        if mask is not None:
            dice_loss = DiceLoss(mode="binary")(logits, mask)
            bce_loss = nn.BCEWithLogitsLoss()(logits, mask)
            # Sum, not product. With a product, each term's gradient is
            # scaled by the other term, so learning stalls as either loss
            # approaches zero.
            return logits, dice_loss + bce_loss

        return logits
# Build the network and move it to the configured device. nn.Module.to()
# returns the module itself, so the chained call binds the moved model.
model = SegmentationModel().to(DEVICE)

def train_fn(data_loader, model, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Puts ``model`` in training mode. For every ``(images, masks)`` batch, it
    moves both tensors to ``DEVICE``, takes the loss from
    ``model(images, masks)``, backpropagates, and steps ``optimizer``.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no batches.
    """
    model.train()
    running = 0.0

    for batch_images, batch_masks in tqdm(data_loader):
        batch_images = batch_images.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE)

        optimizer.zero_grad()
        _, batch_loss = model(batch_images, batch_masks)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(data_loader)
def eval_fn(data_loader, model):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    Runs in eval mode with autograd disabled and updates no parameters.
    Each batch is moved to ``DEVICE`` before being fed to
    ``model(images, masks)``.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no batches.
    """
    model.eval()
    accumulated = 0.0

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(data_loader):
            _, batch_loss = model(batch_images.to(DEVICE), batch_masks.to(DEVICE))
            accumulated += batch_loss.item()

    return accumulated / len(data_loader)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# np.Inf was removed in NumPy 2.0 (AttributeError); np.inf works on all versions.
best_valid_loss = np.inf

for i in range(EPOCHS):

    train_loss = train_fn(trainloader, model, optimizer)
    valid_loss = eval_fn(validloader, model)

    # Checkpoint only when validation loss beats the best seen so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), "best_model.pt")
        print("Saved model")
        best_valid_loss = valid_loss

    print(f"Epoch : {i+1} Train_loss : {train_loss} Valid_loss : {valid_loss}")

idx = 20

# Load the checkpoint from the same relative path the training loop saves to.
# The old hard-coded "/content/..." path only worked when the cwd was /content.
# map_location lets a checkpoint saved on one device load onto DEVICE.
model.load_state_dict(torch.load("best_model.pt", map_location=DEVICE))
# Set eval mode explicitly for inference rather than relying on eval_fn's
# last call (layers such as BatchNorm behave differently in train mode).
model.eval()

image, mask = validset[idx]

with torch.no_grad():
    logits_mask = model(image.to(DEVICE).unsqueeze(0))  # (C,H,W) -> (1,C,H,W)
    pred_mask = torch.sigmoid(logits_mask)
    pred_mask = (pred_mask > 0.5) * 1

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))

ax1.set_title('IMAGE')
# imshow expects (H, W, C). The dataset yields a channel-first image in
# [0, 1], so permute instead of showing only channel 0.
ax1.imshow(image.permute(1, 2, 0))

ax2.set_title('GROUND TRUTH')
ax2.imshow(mask[0], cmap='gray')

ax3.set_title('PREDICTION')
ax3.imshow(pred_mask.detach().cpu().squeeze(0)[0], cmap='gray')

In [None]:
def forward(x):
    """Linear-model prediction: slope ``w`` times ``x`` plus bias ``b``.

    ``w`` and ``b`` are module-level globals looked up at call time.
    """
    prediction = w * x + b
    return prediction
def criterion(yhat, y):
    """Return the mean squared error between predictions and targets."""
    residual = yhat - y
    return torch.mean(residual ** 2)
# Loss history; train_model appends one entry per iteration.
LOSS = []
# Step size for the hand-written gradient-descent update in train_model.
lr = 1e-1
def train_model(iter):
    """Fit the globals ``w`` and ``b`` to ``X``/``Y`` by manual gradient descent.

    Each of the ``iter`` iterations does the following:
    - predicts with ``forward`` and scores the prediction with ``criterion``;
    - records the current parameters and loss on ``get_surface``, redrawing
      its plot every 3rd iteration;
    - appends the loss to the global ``LOSS``;
    - backpropagates, then replaces ``w.data``/``b.data`` with a step of
      size ``lr`` against the gradient.

    ``w`` and ``b`` must have a populated ``.grad`` after ``loss.backward()``,
    so presumably they are leaf tensors created with ``requires_grad=True``.

    Args:
        iter: number of gradient-descent iterations to run. Note that it
            shadows the builtin ``iter`` inside this function.
    """
    
    # Loop
    for epoch in range(iter):
        
        # make a prediction
        Yhat = forward(X)
        
        # calculate the loss 
        loss = criterion(Yhat, Y)

        # Section for plotting: record the current (w, b, loss) point and
        # redraw the parameter-space plot every 3rd iteration.
        get_surface.set_para_loss(w.data.tolist(), b.data.tolist(), loss.tolist())
        if epoch % 3 == 0:
            get_surface.plot_ps()
            
        # store the loss in the list LOSS
        # NOTE(review): this stores the loss tensor itself, not a float, so each
        # entry keeps a reference to its autograd history. Appending loss.item()
        # would store plain floats; check how LOSS is consumed before changing.
        LOSS.append(loss)
        
        # backward pass: compute gradient of the loss with respect to all the learnable parameters
        loss.backward()
        
        # update parameters slope and bias (assigning through .data keeps the
        # update out of the autograd graph)
        w.data = w.data - lr * w.grad.data
        b.data = b.data - lr * b.grad.data
        
        # zero the gradients before the next backward pass; PyTorch
        # accumulates into .grad rather than overwriting it
        w.grad.data.zero_()
        b.grad.data.zero_()