Migrate Apex to PyTorch AMP
dtpreda committed Oct 17, 2023
1 parent 725a069 commit b33184f
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions src/adaVAE.py
@@ -20,9 +20,10 @@
 import datetime
 
 from torch.utils.data import Dataset, DataLoader
-from apex.optimizers import FusedAdam
-from apex import amp
-from apex.fp16_utils import FP16_Optimizer
+# from apex.optimizers import FusedAdam
+# from apex import amp
+# from apex.fp16_utils import FP16_Optimizer
+from torch.cuda.amp import autocast, GradScaler
 from transformers.modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
 from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AdamW, get_linear_schedule_with_warmup, Conv1D
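For context, the torch.cuda.amp imports take over what the three Apex utilities provided: autocast casts forward-pass ops to mixed precision (replacing the amp.initialize opt_level machinery), GradScaler handles loss scaling, and the stock AdamW already used in this file makes FusedAdam unnecessary. A minimal sketch of the resulting pattern; the model, optimizer, and tensors here are placeholders, not objects from adaVAE.py:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(8, 1).cuda()                        # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler()                                        # replaces the apex loss-scaling setup

    x = torch.randn(4, 8, device="cuda")
    y = torch.randn(4, 1, device="cuda")

    optimizer.zero_grad()
    with autocast():                                             # forward pass runs in mixed precision
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()                                # backward on the scaled loss
    scaler.step(optimizer)                                       # skips the update if inf/NaN gradients appear
    scaler.update()                                              # adjusts the loss scale for the next iteration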

@@ -224,22 +225,26 @@ def compute_loss(device, model, x_tokens, input_tokens, att_mask, loss_fn, beta,
     else:
         return loss, ce_loss, regularization_loss, mean, logvar
 
-def train_step(device, model, optimizer, x_tokens, input_tokens, att_mask, loss_fn, beta, kl_rate, reg_loss_type, from_mean, fb):
+def train_step(scaler, device, model, optimizer, x_tokens, input_tokens, att_mask, loss_fn, beta, kl_rate, reg_loss_type, from_mean, fb):
     optimizer.zero_grad()
-    loss, ce_loss, reg_loss, _, _ = compute_loss(device, model, x_tokens, input_tokens, att_mask, loss_fn,
-                                                 beta, kl_rate, reg_loss_type, weighted_sample=False, from_mean=from_mean, fb=fb)
-    with amp.scale_loss(loss, optimizer) as scaled_loss:
-        scaled_loss.backward()
-    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0) # max_grad_norm=1.0
-    # loss.backward()
-    # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # max_grad_norm=1.0
-    optimizer.step()
-    # output.append((loss.item(), ce_loss.mean().item(), reg_loss.item()))
+
+    with autocast():
+        loss, ce_loss, reg_loss, _, _ = compute_loss(device, model, x_tokens, input_tokens, att_mask, loss_fn,
+                                                     beta, kl_rate, reg_loss_type, weighted_sample=False, from_mean=from_mean, fb=fb)
+
+    scaler.scale(loss).backward()
+    scaler.unscale_(optimizer)
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # max_grad_norm=1.0
+
+    scaler.step(optimizer)
+    scaler.update()
 
     if reg_loss_type == "adversarial":
         ## reg_loss: [discriminator loss, generator loss, KL loss]
         reg_loss[2] = reg_loss[2].mean()
     else:
         reg_loss = reg_loss.mean()
 
     return loss.item(), ce_loss.mean().item(), reg_loss
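A note on the ordering in the new train_step above: after scaler.scale(loss).backward() the gradients still carry the loss-scale factor, so scaler.unscale_(optimizer) is called before clip_grad_norm_ so that max_grad_norm=1.0 applies to the true gradient magnitudes; scaler.step(optimizer) then uses the already-unscaled gradients and silently skips the parameter update if any of them are inf/NaN. If it ever matters to know whether a step was skipped (for example, to avoid advancing a learning-rate scheduler on it), a hedged, self-contained sketch not taken from this commit is:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(8, 1).cuda()                         # placeholder, not the AdaVAE model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    optimizer.zero_grad()
    with autocast():
        loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                                    # bring gradients back to true magnitude
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)       # clip the unscaled gradients
    scale_before = scaler.get_scale()
    scaler.step(optimizer)                                        # skipped when inf/NaN gradients are found
    scaler.update()
    step_was_skipped = scaler.get_scale() < scale_before          # update() shrinks the scale after a skipped step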

def train(args):
@@ -491,7 +496,8 @@ def train(args):
 
     optimizer = AdamW(AdaVAE.parameters(), lr=args.lr, correct_bias=True)
     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
-    AdaVAE, optimizer = amp.initialize(AdaVAE, optimizer, opt_level=args.fp16_opt_level)
+    # AdaVAE, optimizer = amp.initialize(AdaVAE, optimizer, opt_level=args.fp16_opt_level)
+    scaler = GradScaler()
 
     ## load ckpt
     if args.load:
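Since GradScaler() now stands in for amp.initialize, two optional follow-ups are worth noting (neither is part of this commit): AMP can be gated behind a flag by constructing the scaler with enabled=..., and when training resumes from a checkpoint the scaler state can be saved and restored alongside the model and optimizer so the learned loss scale survives the restart. A hedged sketch with a hypothetical use_amp flag and checkpoint path:

    import torch
    from torch.cuda.amp import GradScaler

    use_amp = True                                                # hypothetical flag; could be driven by a CLI argument
    scaler = GradScaler(enabled=use_amp)                          # enabled=False makes scale/step/update pass-throughs

    # Persist and restore the scaler together with the rest of the training state.
    torch.save({"scaler": scaler.state_dict()}, "checkpoint.pt")  # hypothetical checkpoint path
    state = torch.load("checkpoint.pt")
    scaler.load_state_dict(state["scaler"])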
@@ -833,7 +839,7 @@ def val_step(val_loader):
                 kl_rate = args.kl_rate / args.latent_size
             else:
                 kl_rate = args.kl_rate
-            loss, ce_loss, regul_loss = train_step(device, AdaVAE, optimizer, x_ids, input_ids, attention_mask,
+            loss, ce_loss, regul_loss = train_step(scaler, device, AdaVAE, optimizer, x_ids, input_ids, attention_mask,
                                                    loss_fn, beta, kl_rate, args.reg_loss, False, args.fb)
             if args.reg_loss == "adversarial":
                 d_loss, g_loss, kld = regul_loss[0].item(), regul_loss[1].item(), regul_loss[2].item()
