ICCV_model 3: PoCCA with the new trick, without merging the two branches, plus the new FPS with selected first centroids

In [1]:
import os
import torch

import sys
sys.path.append("..") 

from models.iccv_model_3 import SimAttention_ICCV_3
from data.shapenet_loader import ShapeNetCLS

# Select the compute device: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyper-parameters.
BATCH_SIZE = 16
BATCS_SIZE = BATCH_SIZE  # misspelled legacy name, kept so other cells using it still work
PATCH_NUM = 8   # passed to SimAttention_ICCV_3 as ``patch_num``
EPOCHS = 50     # number of training epochs; also the cosine-annealing period (T_max)

# Data. 1024 is the second positional argument of ShapeNetCLS
# (presumably points per cloud -- confirm in data.shapenet_loader).
root = r'/p/home/jusers/li39/juwels/H_AttentionProject'
dataset = ShapeNetCLS(root, 1024)
trainDataLoader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
    pin_memory=device.type == "cuda",
)

# Model. Move it to the same device that train_1_epoch sends the inputs to,
# and do so before building the optimizer over its parameters.
model = SimAttention_ICCV_3(patch_num=PATCH_NUM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
# Cosine decay of the learning rate from 1e-3 to eta_min=1e-5 over EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)
# Alternatives previously tried (kept for reference):
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)

def train_1_epoch(model, optimizer, data_loader, device):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Each batch is unpacked into two inputs, which are moved to ``device`` and
    passed to ``model``. The model output is reduced with ``.mean()`` to a scalar
    loss, back-propagated, and followed by one optimizer step.

    Args:
        model: module called as ``model(view_a, view_b)``; put into train mode.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        data_loader: iterable yielding ``(view_a, view_b)`` pairs.
        device: device that receives the inputs and the running-mean accumulator.

    Returns:
        float: arithmetic mean of the per-batch losses (0.0 if the loader is empty).
    """
    model.train()
    # The average is kept as a tensor; ``.item()`` is called only once, at the end.
    running_avg = torch.zeros(1).to(device)
    optimizer.zero_grad()
    for batch_idx, (view_a, view_b) in enumerate(data_loader):
        batch_loss = model(view_a.to(device), view_b.to(device)).mean()
        batch_loss.backward()
        # Incremental mean: avg_k = (avg_{k-1} * k + x_k) / (k + 1).
        running_avg = (running_avg * batch_idx + batch_loss.detach()) / (batch_idx + 1)
        optimizer.step()
        optimizer.zero_grad()
    return running_avg.item()
    
    
# Checkpoints go to <root>/weights/iccv_3. Create the folder up front so that
# torch.save does not fail with a missing-directory error partway through training.
weight_dir = os.path.join(root, 'weights', 'iccv_3')
os.makedirs(weight_dir, exist_ok=True)

print('start training!')
for epoch in range(EPOCHS):
    print('--------Epoch {} is running--------'.format(epoch))
    loss = train_1_epoch(model, optimizer, trainDataLoader, device)
    print("Loss: ", loss)
    print('\n')
    # Advance the LR schedule once per epoch, after that epoch's optimizer steps.
    scheduler.step()
    # Checkpoint every 5 epochs (starting at epoch 0) and after the final epoch.
    if epoch % 5 == 0 or epoch + 1 == EPOCHS:
        # NOTE(review): the '50' in the file name looks like EPOCHS hard-coded -- confirm.
        weight_name = 'cls_dg_iccv3_50_fps_' + str(epoch) + '.pth'
        torch.save(model.state_dict(), os.path.join(weight_dir, weight_name))
        print('Model Saved!')

  from .autonotebook import tqdm as notebook_tqdm


start training!
--------Epoch 0 is running--------
Loss:  0.40938600897789


Model Saved!
--------Epoch 1 is running--------
Loss:  0.2807232737541199


--------Epoch 2 is running--------
Loss:  0.2849440574645996


--------Epoch 3 is running--------
Loss:  0.20998677611351013


--------Epoch 5 is running--------
Loss:  0.23541803658008575


Model Saved!
--------Epoch 6 is running--------
Loss:  0.24000932276248932


--------Epoch 7 is running--------
Loss:  0.21021023392677307


--------Epoch 8 is running--------
Loss:  0.2198854386806488


--------Epoch 9 is running--------
Loss:  0.2280462384223938


--------Epoch 10 is running--------
Loss:  0.19354966282844543


Model Saved!
--------Epoch 11 is running--------
Loss:  0.18366220593452454


--------Epoch 12 is running--------
Loss:  0.17937873303890228


--------Epoch 13 is running--------
Loss:  0.18415860831737518


--------Epoch 14 is running--------
Loss:  0.17145465314388275


--------Epoch 15 is running--------
Loss:  0.167742