In [None]:
# Install the helper library into the environment of the *running kernel*.
# `%pip` targets the kernel's interpreter, whereas `!pip` runs whichever pip is
# first on the shell PATH and may install into a different environment.
%pip install -q torch_snippets

In [None]:
from torch_snippets import *
import h5py
from scipy import io
import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from glob import glob 
from tqdm import tqdm 
from torch.optim import Adam, lr_scheduler

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
class CrowdCountingData(Dataset):
  """Crowd images paired with a density map and a head count.

  Samples are matched by file stem across three directories:
  ``{image_dir}/{stem}.jpg``, ``{heat_map_dir}/{stem}.h5`` and
  ``{ground_truth_dir}/{stem}.mat``.

  Args:
    image_dir: directory of ``.jpg`` images; the sample list is whatever
      torch_snippets' ``stems`` helper returns for this directory.
    heat_map_dir: directory of ``.h5`` files exposing a ``density`` dataset.
    ground_truth_dir: directory of ``.mat`` annotation files.
    transformer: callable applied to every image inside ``collate_fn``; it
      must return a tensor so the images can be concatenated into a batch.
  """
  def __init__(self, image_dir, heat_map_dir, ground_truth_dir, transformer):
    self.image_dir = image_dir
    self.heat_map_dir = heat_map_dir
    self.ground_truth_dir = ground_truth_dir
    self.img_list = stems(image_dir, silent = True)
    self.transformer = transformer

  def __len__(self):
    """Number of samples, i.e. the number of stems found in ``image_dir``."""
    return len(self.img_list)

  def __getitem__(self, index):
    """Load one sample.

    Returns:
      ``(image, density_map, count)`` where ``image`` and ``density_map`` are
      copies of the loaded arrays and ``count`` is the number of annotated
      points read from the ``.mat`` file.
    """
    # The stem list lives in `img_list`; `self.stems` was never assigned.
    _stem = self.img_list[index]
    img = f'{self.image_dir}/{_stem}.jpg'
    hp =  f'{self.heat_map_dir}/{_stem}.h5'
    gt = f'{self.ground_truth_dir}/{_stem}.mat'

    mat = io.loadmat(gt)
    # MATLAB variable names cannot contain spaces, so files written by MATLAB
    # cannot carry an 'image info' key. Prefer 'image_info' (presumably the key
    # these annotations use -- TODO confirm) and fall back to the old spelling.
    info = mat['image_info'] if 'image_info' in mat else mat['image info']
    pts = len(info[0,0][0,0][0])

    image = read(img, 1)
    with h5py.File(hp, 'r') as hf:
      gt = hf['density'][:]
    # Shrink the map to 1/8 scale; the x64 (= 8 * 8) factor presumably keeps
    # the map's sum, i.e. the head count, unchanged -- TODO confirm.
    gt = resize(gt, 1/8)*64

    return image.copy(), gt.copy(), pts

  def collate_fn(self, batch):
    """Turn a list of ``__getitem__`` samples into batched tensors on ``device``.

    Returns:
      ``(ims, gts, pts)``: transformed images concatenated along dim 0,
      density maps as a float32 tensor with a channel axis added (``(B, 1, H, W)``
      for 2-D maps), and the counts as a 1-D tensor.
    """
    ims, gts, pts = list(zip(*batch))
    ims = torch.cat([self.transformer(im)[None] for im in ims]).to(device)
    # `self.copy` does not exist. Convert each map directly instead of routing
    # it through the image transformer, so image-only steps are never applied
    # to the targets; float32 matches the model output for the loss.
    gts = torch.cat([torch.as_tensor(gt, dtype = torch.float32)[None, None] for gt in gts]).to(device)
    pts = torch.tensor(pts).to(device)

    return ims, gts, pts

In [None]:
def make_layers(cfg, in_channels = 3, batch_norm = False, dilation = False):
  """Build a VGG-style stack of 3x3 convolutions and 2x2 max-pools.

  Args:
    cfg: sequence describing the stack. ``'M'`` inserts
      ``MaxPool2d(kernel_size=2, stride=2)``; any other entry is the number of
      output channels of a 3x3 convolution followed by ReLU.
    in_channels: channel count of the input to the first convolution.
    batch_norm: insert ``BatchNorm2d`` between each convolution and its ReLU.
    dilation: when true, every convolution uses dilation 2 with padding 2;
      otherwise dilation 1 with padding 1. Both keep the spatial size.

  Returns:
    ``nn.Sequential`` containing the layers in ``cfg`` order.
  """
  d_rate = 2 if dilation else 1

  layers = []
  for v in cfg:
    if v == 'M':
      layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
      continue
    # The dilation rate must be passed to the convolution itself. Setting only
    # `padding = 2` leaves the kernel undilated and grows the feature map by
    # two pixels per layer.
    conv2d = nn.Conv2d(in_channels, v, kernel_size = 3, padding = d_rate, dilation = d_rate)
    layers.append(conv2d)
    if batch_norm:
      layers.append(nn.BatchNorm2d(v))
    layers.append(nn.ReLU(inplace = True))
    in_channels = v
  return nn.Sequential(*layers)

class CSRNet(nn.Module):
  """Fully-convolutional network that maps an image to a one-channel map.

  Both stages are assembled by ``make_layers`` with batch norm enabled: a
  front end containing three ``'M'`` (max-pool) entries, and a back end that
  is requested with ``dilation = True``. A final 1x1 convolution reduces the
  back end's 64 channels to a single output channel.
  """
  def __init__(self):
    super(CSRNet, self).__init__()
    # Channel plan of the front end; 'M' marks a max-pool.
    self.frontend_features = [
      64, 64, 'M',
      128, 128, 'M',
      256, 256, 256, 'M',
      512, 512, 512,
    ]
    # Channel plan of the back end (no pooling).
    self.backend_features = [512, 512, 512, 256, 128, 64]

    frontend_out = self.frontend_features[-1]  # 512
    backend_out = self.backend_features[-1]    # 64
    self.frontend = make_layers(self.frontend_features, batch_norm = True)
    self.backend = make_layers(
      self.backend_features,
      in_channels = frontend_out,
      batch_norm = True,
      dilation = True,
    )
    self.output_layer = nn.Conv2d(backend_out, 1, kernel_size = 1)

  def forward(self, x):
    """Apply front end, back end and the 1x1 output convolution in order."""
    return self.output_layer(self.backend(self.frontend(x)))

In [None]:
def train_batch(model, data, optimizer, criterion):
  """Perform one optimisation step on a single batch.

  Args:
    model: module called on the batch images.
    data: ``(images, targets, counts)`` triple; the counts are not used here.
    optimizer: optimizer whose gradients are reset and which is stepped once.
    criterion: loss applied to ``(predictions, targets)``.

  Returns:
    ``(loss, count_error)`` as Python floats, where ``count_error`` is the
    absolute difference between the summed predictions and the summed targets
    over the whole batch.
  """
  model.train()
  optimizer.zero_grad()
  images, targets, _ = data

  predictions = model(images)
  batch_loss = criterion(predictions, targets)
  batch_loss.backward()
  optimizer.step()

  # L1 between two scalars is simply |sum(pred) - sum(target)|.
  count_error = nn.functional.l1_loss(predictions.sum(), targets.sum())
  return batch_loss.item(), count_error.item()

def valid_batch(model, data, criterion):
  """Evaluate one batch without updating the model.

  Args:
    model: module called on the batch images; switched to eval mode.
    data: ``(images, targets, counts)`` triple; the counts are not used here.
    criterion: loss applied to ``(predictions, targets)``.

  Returns:
    ``(loss, count_error)`` as Python floats, where ``count_error`` is the
    absolute difference between the summed predictions and the summed targets
    over the whole batch.
  """
  model.eval()
  ims, gts, pts = data
  # No backward pass happens here, so disable autograd: it avoids building and
  # holding a graph for every validation batch.
  with torch.no_grad():
    _gts = model(ims)
    loss = criterion(_gts, gts)
    pts_loss = nn.L1Loss()(_gts.sum(), gts.sum())
  return loss.item(), pts_loss.item()

def get_model():
  """Create the network together with its optimizer and LR scheduler.

  Returns:
    ``(model, optimizer, scheduler)``: a ``CSRNet`` moved to ``device``, an
    Adam optimizer over its parameters (lr 1e-6), and a ``StepLR`` scheduler
    that multiplies the learning rate by 0.1 every 10 scheduler steps.
  """
  model = CSRNet().to(device)
  optimizer = Adam(model.parameters(), lr = 1e-6)
  # `lr_scheduler` is a module and cannot be called; the `step_size` / `gamma`
  # arguments are those of the StepLR scheduler class.
  scheduler = lr_scheduler.StepLR(optimizer, step_size = 10, gamma = 0.1)
  return model, optimizer, scheduler

def save_plot(tr_list, val_list, title):
  """Plot training vs. validation loss per epoch, save it and display it.

  Args:
    tr_list: per-epoch training losses.
    val_list: per-epoch validation losses.
    title: figure title; also used as the output file name ``{title}.png``
      in the current working directory.
  """
  # Draw on a dedicated figure so leftovers on the current pyplot figure
  # cannot leak into the saved image.
  fig, ax = plt.subplots()
  ax.plot(tr_list, label = "Training")
  ax.plot(val_list, label = "Validation")
  ax.set_xlabel("Epoch")
  ax.set_ylabel("Loss")
  ax.legend()
  ax.set_title(title)
  # Save before showing: after `plt.show()` the figure may already be released
  # by the backend, in which case `savefig` writes an empty image.
  fig.savefig(f"{title}.png")
  print(f"Saved {title}.png")
  plt.show()
  plt.close(fig)

def run(epoch_range, model_state_dict = None):
  """Train and validate CSRNet, checkpointing whenever validation improves.

  For every epoch the mean training and validation losses are recorded and
  printed. Each time the validation loss reaches a new minimum the weights
  are written to ``Epoch_{n}_model.pth``. A loss curve is saved at the end.

  Args:
    epoch_range: number of epochs to train for.
    model_state_dict: optional state dict loaded into the freshly built model
      before training starts (e.g. to resume from a checkpoint).
  """
  transformer = transforms.Compose([transforms.ToTensor()])

  # Dataset locations -- placeholders that must be filled in before running.
  # NOTE(review): `image_b_dir` / `vl_image_b_dir` are not used below; the
  # `*_image_folder_dir` values are assumed to be the image directories.
  image_b_dir = ""
  image_folder_dir = ""
  heat_map_dir = ""
  ground_truth_dir = ""

  vl_image_b_dir = ""
  vl_image_folder_dir = ""
  vl_heat_map_dir = ""
  vl_ground_truth_dir = ""

  # The dataset requires its directories and the transformer.
  tr_set = CrowdCountingData(image_folder_dir, heat_map_dir, ground_truth_dir, transformer)
  vr_set = CrowdCountingData(vl_image_folder_dir, vl_heat_map_dir, vl_ground_truth_dir, transformer)
  tr_loader = DataLoader(tr_set, batch_size = 16, shuffle = True, collate_fn = tr_set.collate_fn)
  # Validation order is irrelevant, so keep it fixed for repeatable numbers.
  vr_loader = DataLoader(vr_set, batch_size = 16, shuffle = False, collate_fn = vr_set.collate_fn)

  model, optimizer, scheduler = get_model()
  if model_state_dict is not None:
    model.load_state_dict(model_state_dict)
  criterion = nn.MSELoss()

  training_loss = []
  validation_loss = []
  # Start at +inf so the first epoch always produces a checkpoint, whatever
  # the scale of the loss.
  min_validation_loss = float('inf')
  print("---------------------------------------------------------")
  for epoch in range(epoch_range):
    tr_loss = []
    for data in tqdm(tr_loader, desc = "TRAINING"):
      loss, _ = train_batch(model, data, optimizer, criterion)
      tr_loss.append(loss)
    # Average the collected batch losses (not the yet-undefined result name).
    tr_loss_value = np.mean(tr_loss)
    training_loss.append(tr_loss_value)

    vl_loss = []
    # No gradients are needed while validating.
    with torch.no_grad():
      for data in tqdm(vr_loader, desc = "VALIDATION"):
        loss, _ = valid_batch(model, data, criterion)
        vl_loss.append(loss)
    vl_loss_value = np.mean(vl_loss)
    validation_loss.append(vl_loss_value)

    print("\nEpoch: {}/{} | Average Training Loss: {:.4f} | Average Validation Loss: {:.4f}".format(
      epoch+1,
      epoch_range,
      tr_loss_value, 
      vl_loss_value,
    ))

    if min_validation_loss > vl_loss_value:
      min_validation_loss = vl_loss_value
      torch.save(model.state_dict(),f'Epoch_{epoch+1}_model.pth')
      print("New Model Saved!")

    # Advance the LR schedule once per epoch; without this call the scheduler
    # returned by `get_model` has no effect.
    scheduler.step()

  save_plot(training_loss, validation_loss, "LOSS")

In [None]:
# Entry point: start a 100-epoch training run when executed as the main module.
if __name__ == "__main__":
  run(epoch_range = 100)