Skip to content

Commit

Permalink
update_torchxla
Browse files Browse the repository at this point in the history
  • Loading branch information
OliverRensu committed May 2, 2024
1 parent 765e8f5 commit da84628
Show file tree
Hide file tree
Showing 83 changed files with 9,104 additions and 3 deletions.
8 changes: 8 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/GKDv2.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions DiGPT_torchxla/21K-FT/.idea/other.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

135 changes: 135 additions & 0 deletions DiGPT_torchxla/21K-FT/engine_finetune_xla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import logging

import torch
import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as test_utils
import torchvision.transforms as transforms
from util import misc
from util import lr_sched
from typing import Iterable, Optional

from timm.data import Mixup
from timm.utils import accuracy
import numpy as np
from util.precision import get_autocast
import math
import sys
import time

# import wandb

def after_train_step(args, times, data_iter_step, epoch, num_batches_per_epoch,
                     losses, lr):
    """XLA step closure: log coarse training progress from the master process.

    Registered via ``xm.add_step_closure`` so it runs after the device step
    has materialized ``losses``.

    Args:
        args: runtime config; must provide ``rank``, ``log_freq``,
            ``accum_iter``, ``batch_size`` and ``world_size``.
        times: wall-clock seconds spent on this step (must be > 0).
        data_iter_step: index of the current micro-batch in the epoch.
        epoch: current epoch index.
        num_batches_per_epoch: number of optimizer steps per epoch.
        losses: per-micro-batch loss tensor (already divided by accum_iter);
            ``.item()`` is only called when we actually log.
        lr: current learning rate of the logged param group.
    """
    # NOTE loss is coarsely sampled: just the master node, once per log_freq
    # steps, plus the final batch of the epoch.
    # BUGFIX: the original wrote `(step % log_freq or step == n) == 0`, where
    # the `== 0` applied to the whole `or` expression — that inverted the
    # end-of-epoch condition (it logged on multiples of log_freq EXCEPT the
    # last batch). The parenthesization below expresses the intent.
    if args.rank == 0 and (data_iter_step % args.log_freq == 0
                           or data_iter_step == num_batches_per_epoch):
        # losses was scaled down by accum_iter before backward; undo for display.
        loss = losses.item() * args.accum_iter

        percent_complete = 100.0 * data_iter_step // args.accum_iter / num_batches_per_epoch
        samples_per_second = args.accum_iter * args.batch_size * args.world_size / times
        loss_log = f"loss: {loss}"
        logging.info(
            f"Train Epoch: {epoch} ({percent_complete:.0f}%)] "
            # BUGFIX: added the missing space before "LR:" and changed the
            # format spec `:7f` (min field width 7) to `:.7f` (7-digit
            # precision), which is what an LR log line needs.
            f"Batch (t): {times:.3f}, {samples_per_second:#g}/s "
            f"LR: {lr:.7f} " + loss_log
        )

def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, writer,
                    mixup_fn=None, args=None):
    """Run one fine-tuning epoch on an XLA (TPU) device.

    Gradients are accumulated over ``args.accum_iter`` micro-batches before
    each optimizer step; ``xm.mark_step()`` is issued every iteration so the
    XLA graph is cut at a fixed cadence. Logging is deferred to an XLA step
    closure (``after_train_step``) to avoid extra host-device syncs.

    Args:
        model: network to fine-tune (switched to train mode here).
        criterion: loss applied to ``model``'s outputs.
        data_loader: iterable of ``(samples, targets)`` batches; must support
            ``len()``.
        optimizer: optimizer whose last param group's ``"lr"`` is logged.
        device: XLA device that batches are moved to.
        epoch: current epoch index, used as the integer part of the
            fractional-epoch argument to the LR schedule.
        loss_scaler: unused in this XLA path — presumably kept for interface
            parity with a GPU engine; TODO confirm.
        writer: unused here — presumably a TensorBoard writer kept for
            interface parity; TODO confirm.
        mixup_fn: optional Mixup/CutMix transform applied per batch.
        args: runtime config; must provide ``accum_iter`` plus the fields
            ``after_train_step`` reads (``rank``, ``log_freq``, ``batch_size``,
            ``world_size``).

    Returns:
        dict with ``train_loss`` and ``train_lr``: per-batch averages,
        mesh-reduced (mean) across all XLA processes.
    """
    model.train(True)

    accum_iter = args.accum_iter
    autocast = get_autocast()
    optimizer.zero_grad()
    total_train_samples, total_train_batches, total_train_loss, total_train_lr = 0, 0, 0.0, 0.0
    num_batches_per_epoch = len(data_loader) // args.accum_iter
    end = time.time()
    for data_iter_step, (samples, targets) in enumerate(data_loader):
        # Adjust the LR only once per 200 steps, at the midpoint of the
        # current 200-step window (e.g. steps 0..199 all use the LR computed
        # for step 100). This coarsens the schedule, presumably to avoid
        # re-tracing the XLA graph on every LR change — TODO confirm.
        if data_iter_step % accum_iter == 0:
            if data_iter_step % 200 == 0:
                middle_step = data_iter_step // 200 * 200 + 100
                lr_sched.adjust_learning_rate(optimizer, middle_step / len(data_loader) + epoch, args)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        samples = samples.to(device)
        targets = targets.to(device)
        with autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)
        # Scale the loss so the summed gradients over accum_iter micro-batches
        # equal one full-batch gradient.
        loss /= accum_iter
        lr = optimizer.param_groups[-1]["lr"]

        # NOTE(review): loss.item() forces a host-device sync every iteration,
        # which can serialize the XLA pipeline — confirm this is intended.
        total_train_loss += loss.item()
        total_train_lr += lr
        total_train_batches += 1
        loss.backward()
        if (data_iter_step + 1) % accum_iter != 0:
            # Mid-accumulation: cut the graph but do not step the optimizer.
            xm.mark_step()
        else:
            #xm.reduce_gradients(optimizer)
            optimizer.step()
            xm.mark_step()
            optimizer.zero_grad()

        # Logging runs asynchronously after the step's tensors materialize.
        after_train_step_args = [args, time.time() - end, data_iter_step, epoch, num_batches_per_epoch, loss, lr]
        xm.add_step_closure(after_train_step, after_train_step_args)
        end = time.time()
    train_loss = total_train_loss / total_train_batches
    train_lr = total_train_lr / total_train_batches
    # Average the scalar metrics across all XLA processes.
    train_loss = xm.mesh_reduce('train_loss', train_loss, np.mean)
    train_lr = xm.mesh_reduce('train_lr', train_lr, np.mean)

    return {"train_loss":train_loss, "train_lr":train_lr}


@torch.no_grad()
def evaluate(data_loader, model, epoch, device, mask_num=None, num_patches=None):
    """Evaluate classification accuracy and loss on an XLA (TPU) device.

    Accumulates loss and top-1/top-5 accuracy over the whole loader, then
    mesh-reduces (mean) the per-process scalars across all XLA processes.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: network to evaluate (switched to eval mode here).
        epoch: current epoch index, forwarded to the XLA test-update printer.
        device: XLA device that batches are moved to.
        mask_num: unused here — presumably kept for interface parity with a
            masked-evaluation variant; TODO confirm.
        num_patches: unused here — same caveat as ``mask_num``.

    Returns:
        Tuple ``(test_acc1, test_loss, test_acc5)`` — note the ordering:
        accuracy@1, loss, accuracy@5.
    """
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    autocast = get_autocast()
    total_test_samples, total_test_batches, total_test_loss, total_test_acc1, total_test_acc5 = 0, 0, 0.0, 0.0, 0.0
    for images, target in data_loader:
        images = images.to(device)
        target = target.to(device)
        # torch.no_grad() here is redundant with the decorator; harmless.
        with autocast(), torch.no_grad():
            output = model(images)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        # Accuracies from timm are per-batch percentages, so weight by batch
        # size; loss is averaged per batch (divided by batch count below).
        total_test_loss += loss
        total_test_acc1 += acc1 * batch_size
        total_test_acc5 += acc5 * batch_size
        total_test_batches += 1
        total_test_samples += batch_size

    # .item() syncs the accumulated device tensors back to the host once,
    # after the whole loader has been consumed.
    test_acc1 = total_test_acc1.item() / total_test_samples
    test_acc5 = total_test_acc5.item() / total_test_samples
    test_loss = total_test_loss.item() / total_test_batches
    # NOTE(review): this line logs the LOCAL (pre-mesh-reduce) values from
    # every process — confirm that is intended and not duplicate logging.
    logging.info("Loss: {}, Top-1: {}, Top-5: {}".format(test_loss, test_acc1, test_acc5))
    xm.add_step_closure(test_utils.print_test_update, args=(device, epoch, test_acc1))
    # Average the per-process scalars across all XLA processes.
    test_acc1 = xm.mesh_reduce('test_acc1', test_acc1, np.mean)
    test_acc5 = xm.mesh_reduce('test_acc5', test_acc5, np.mean)
    test_loss = xm.mesh_reduce('test_loss', test_loss, np.mean)

    if misc.is_main_process():
        logging.info('* Acc@1 {top1:.3f} Acc@5 {top5:.3f} loss {losses:.3f}'
                     .format(top1=test_acc1, top5=test_acc5, losses=test_loss))
    return test_acc1, test_loss, test_acc5
67 changes: 67 additions & 0 deletions DiGPT_torchxla/21K-FT/launch_xla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Adaptation of (pre-elastic) torch.distributed.launch for pytorch xla.
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
"""


import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO

import torch_xla.distributed.xla_multiprocessing as xmp


def parse_args():
    """Build the launcher's CLI parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``num_devices`` (int), ``script`` (str), and
        ``script_args`` (every remaining token, forwarded verbatim).
    """
    cli = ArgumentParser(
        description=("PyTorch distributed training launch helper utility"
                     "that will spawn up multiple distributed processes"))

    # Optional: how many XLA devices to fan the training script out to.
    cli.add_argument(
        "--num-devices", type=int, default=1,
        help="The number of XLA devices to use for distributed training")

    # Positional: path to the single-device training script.
    cli.add_argument(
        "script", type=str,
        help="The full path to the single device training script to be launched"
             "in parallel, followed by all the arguments for the training script")

    # Everything after the script path belongs to the training script itself.
    cli.add_argument("script_args", nargs=REMAINDER)

    return cli.parse_args()


def main():
    """Import the target training script as a module and spawn it on XLA devices.

    The script's directory is appended to ``sys.path`` so it can be imported
    by its bare module name; its CLI arguments are forwarded via ``sys.argv``.
    The imported module must expose an ``_mp_entry`` callable for
    ``xmp.spawn``.
    """
    args = parse_args()

    # Resolve the script path and make its directory importable.
    resolved = os.path.abspath(args.script)
    parent_dir, filename = os.path.split(resolved)
    sys.path.append(parent_dir)

    # Import the script as a module (strip the .py extension).
    module_name, _ = os.path.splitext(filename)
    target = importlib.import_module(module_name)

    # Hand the remaining CLI tokens to the training script as its own argv.
    sys.argv = [args.script] + args.script_args

    print('launching xla...')
    xmp.spawn(target._mp_entry, args=(), nprocs=args.num_devices)


if __name__ == "__main__":
    main()
Loading

0 comments on commit da84628

Please sign in to comment.