In [1]:
import datetime
import time
import os
import torch
import torch.utils.data.dataloader as DataLoader
import sys
import argparse
import transforms as T
import matplotlib.pyplot as plt
import numpy as np

from model import SimpleNet
from Src.ComplexValuedAutoencoderMain_Torch import end_to_end_Net
from dataset import ToFDataset
from train_and_eval import train_one_epoch, evaluate, create_lr_scheduler
from util.logconf import logging


log = logging.getLogger(__name__)
# Verbose by default; switch to logging.INFO or logging.WARN for quieter runs.
log.setLevel(logging.DEBUG)

class PresetTrain:
    """Training-time preprocessing for an ``(img, target)`` pair.

    Pipeline: an optional random horizontal flip followed by a random crop.
    Both ``img`` and ``target`` are passed to the composed transform together
    (presumably so they receive identical random flips/crops -- this depends
    on the project's ``transforms`` module).
    """

    def __init__(self, crop_size, hflip_prob=0.5):
        """
        Args:
            crop_size: size forwarded to ``T.RandomCrop``.
            hflip_prob: probability forwarded to ``T.RandomHorizontalFlip``;
                when it is <= 0 the flip step is left out entirely.
        """
        steps = ([T.RandomHorizontalFlip(hflip_prob)] if hflip_prob > 0 else [])
        steps = steps + [T.RandomCrop(crop_size)]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Run the composed pipeline on ``img`` and ``target``."""
        return self.transforms(img, target)


class PresetEval:
    """Evaluation-time preprocessing for an ``(img, target)`` pair.

    Applies only a deterministic center crop, so validation results are
    reproducible across runs.
    """

    def __init__(self, crop_size):
        """
        Args:
            crop_size: size forwarded to ``T.CenterCrop``.
        """
        center = T.CenterCrop(crop_size)
        self.transforms = T.Compose([center])

    def __call__(self, img, target):
        """Run the center-crop pipeline on ``img`` and ``target``."""
        return self.transforms(img, target)

def get_transform(train, crop_size=256):
    """Return the preprocessing pipeline for the training or evaluation split.

    Args:
        train: if truthy, return the augmenting ``PresetTrain``; otherwise
            the deterministic ``PresetEval``.
        crop_size: crop size forwarded to the chosen preset (default 256).

    Returns:
        A callable taking ``(img, target)``.
    """
    preset_cls = PresetTrain if train else PresetEval
    return preset_cls(crop_size)

class TrainingApp:
    """Train ``end_to_end_Net`` on ``ToFDataset`` and keep the best checkpoint.

    Hyperparameters are hard-coded in ``__init__``. Each epoch runs one
    training pass (``train_one_epoch``) and one validation pass
    (``evaluate``); whenever the validation loss improves, model, optimizer
    and LR-scheduler state are written to ``save_weights/best_model.pth``.
    """

    def __init__(self, args=None):
        # ``args`` is accepted for a future CLI but is currently unused.
        self.lr = 0.0001
        self.path = './data'
        self.batch_size = 72
        self.epochs = 200
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.num_workers = 8
        self.model = self.initModel()
        self.optimizer = self.initOptimizer(self.lr)

    def initModel(self):
        """Build the network and move it to ``self.device`` when CUDA is available."""
        model = end_to_end_Net(1, 1, 1, bilinear=True)
        if self.use_cuda:
            log.info("Using CUDA; %s devices.", torch.cuda.device_count())
            # NOTE: multi-GPU (DataParallel) is not enabled; only the single
            # device in self.device is used even if more GPUs are visible.
            model = model.to(self.device)
        return model

    def initOptimizer(self, lr):
        """Return an Adam optimizer over ``self.model``'s parameters with rate ``lr``."""
        return torch.optim.Adam(self.model.parameters(), lr)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this app's batch/worker settings.

        Memory pinning is enabled only when CUDA is available: it has no
        benefit on a CPU-only machine (and recent PyTorch versions warn).
        """
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.use_cuda,
        )

    def initTrainDL(self):
        """Return the shuffled training DataLoader (random flip/crop augmentation)."""
        train_dataset = ToFDataset(self.path, train=True,
                                   transforms=get_transform(train=True))
        return self._make_loader(train_dataset, shuffle=True)

    def initValDL(self):
        """Return the unshuffled validation DataLoader (center crop only)."""
        val_dataset = ToFDataset(self.path, train=False,
                                 transforms=get_transform(train=False))
        return self._make_loader(val_dataset, shuffle=False)

    def showPlt(self, train_losses, val_losses):
        """Plot training and validation loss curves in two figures and show them.

        The x-axis is derived from the number of recorded losses (not
        ``self.epochs``), so a partial history still plots correctly.
        """
        plt.figure()
        plt.plot(np.arange(1, len(train_losses) + 1), train_losses,
                 label='train_losses')
        plt.xlabel('Epochs')
        plt.ylabel('train_losses')
        plt.title('Training Loss')
        plt.legend()

        plt.figure()
        plt.plot(np.arange(1, len(val_losses) + 1), val_losses,
                 label='val_losses')
        plt.xlabel('Epochs')
        plt.ylabel('val_losses')
        plt.title('Validation Loss')
        plt.legend()

        plt.show()

    def main(self):
        """Run the train/validate loop, checkpoint the best model, then plot losses."""
        save_dir = 'save_weights'
        # Ensure the checkpoint directory exists even when main() is invoked
        # from outside this file's ``__main__`` guard.
        os.makedirs(save_dir, exist_ok=True)

        train_DL = self.initTrainDL()
        val_DL = self.initValDL()
        # +inf guarantees the first epoch is checkpointed whatever the loss scale.
        min_loss = float('inf')

        start_time = time.time()
        train_losses = []
        val_losses = []
        self.lr_scheduler = create_lr_scheduler(
            self.optimizer, num_step=len(train_DL), epochs=self.epochs, warmup=True)

        for epoch_ndx in range(1, self.epochs + 1):
            train_loss, _lr = train_one_epoch(
                self.model, self.optimizer, train_DL, self.device, epoch_ndx,
                self.lr_scheduler, scaler=None)
            val_loss = evaluate(self.model, val_DL, self.device)

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            if val_loss < min_loss:
                min_loss = val_loss
                # Build the checkpoint only when it is actually written.
                save_file = {"model": self.model.state_dict(),
                             "optimizer": self.optimizer.state_dict(),
                             "lr_scheduler": self.lr_scheduler.state_dict(),
                             "epoch": epoch_ndx,
                             }
                torch.save(save_file, os.path.join(save_dir, "best_model.pth"))

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))
        self.showPlt(train_losses, val_losses)


if __name__ == '__main__':
    # Checkpoints are written under save_weights/; exist_ok avoids the
    # check-then-create race and makes reruns a no-op.
    os.makedirs('./save_weights', exist_ok=True)

    # TODO: expose data path, device, batch size, epochs, workers and lr as
    # argparse options; TrainingApp currently hard-codes them in __init__.
    TrainingApp().main()


2023-12-04 16:58:02,476 INFO     pid:41854 __main__:078:initModel Using CUDA; 3 devices.


end_to_end_Net
clear_tof_amp_npys size:3000
clear_tof_pha_npys size:3000
fog_tof_amp_npys_path size:3000
fog_tof_pha_npys_path size:3000
clear_tof_amp_npys size:500
clear_tof_pha_npys size:500
fog_tof_amp_npys_path size:500
fog_tof_pha_npys_path size:500
outside:input size: torch.Size([72, 1, 256, 256])
image1: tensor(-248.2685+40.7867j, device='cuda:0')
image2: tensor(-248.6225+42.4000j, device='cuda:0')
new_x torch.Size([72, 1, 256, 256])
new_x1 tensor(-248.2685+40.7867j, device='cuda:0')
new_x2 tensor(-248.6225+42.4000j, device='cuda:0')
	In Model: input size torch.Size([72, 1, 256, 256]) output size torch.Size([72, 1, 256, 256])
outside:output size: torch.Size([72, 1, 256, 256])
Epoch: [1]  [ 0/42]  eta: 0:04:30  lr: 0.000002  loss: 1472.5143 (1472.5143)  time: 6.4368  data: 3.1725  max mem: 27011
outside:input size: torch.Size([72, 1, 256, 256])
image1: tensor(43.4663+249.0388j, device='cuda:0')
image2: tensor(43.4504+248.9893j, device='cuda:0')
new_x torch.Size([72, 1, 256, 256])

KeyboardInterrupt: 