In [1]:
import MDAnalysis as mda
import matplotlib.pyplot as plt
import math
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Load the graphite/water system: topology from the PSF file, coordinates from
# the DCD trajectory.
# NOTE(review): both paths are absolute and machine-specific, so this cell only
# runs on the original author's machine; consider a configurable data directory.
# NOTE(review): the directory name suggests the coordinates are already
# unwrapped and the file name suggests one frame per 20 ps -- confirm both
# before interpreting displacements computed from this trajectory.
graphite_water_psf = r'D:\all_programming\MSD\data_files\unwrapped\graphite_water.psf'
graphite_trajectory_dcd = r'D:\all_programming\MSD\data_files\unwrapped\output_per_20ps.dcd'

# Universe holding every atom of the system; later cells move through its frames.
global_uni = mda.Universe(graphite_water_psf,graphite_trajectory_dcd)
# Atom group of the atoms named OW in residues named SPCE (presumably the water
# oxygens of an SPC/E water model -- confirm against the PSF). Its .positions
# reflect whichever frame global_uni.trajectory currently points at.
Oxg_uni = global_uni.select_atoms('resname SPCE and name OW')



In [22]:
# Largest z coordinate among the selected OW atoms, one entry per frame.
# These maxima are only used to pick the upper z bound of the partitions below:
# they are generally lower than 65 until some weird extremes occur, which we ignore.
z_maxs = []

# Iterating the trajectory loads each frame in turn, which refreshes
# Oxg_uni.positions; the z column (index 2) is then reduced with NumPy instead
# of copying every atom's z value into a Python list first.
# Later cells select the frame they need explicitly, so they do not depend on
# where this loop leaves the trajectory.
for _ in global_uni.trajectory:
    z_maxs.append(Oxg_uni.positions[:, 2].max())

# Slab layout along z: `partition_count` equal-width slabs between 0 and
# `upper_z_bound`.
upper_z_bound = 65
partition_count = 4

# Lower edge of every slab: 0 followed by each interior boundary.
z_partitions = [
    upper_z_bound * (step / partition_count) for step in range(partition_count)
]
# Close the list with the upper bound itself, so it holds
# partition_count + 1 boundaries (lower and upper bound included).
z_partitions.append(upper_z_bound)

In [57]:
def calculate_msd(position_dict):
    """Compute the mean squared displacement (MSD) of one track for every lag.

    For each lag ``dt`` (counted in samples), the squared displacement
    ``dx**2 + dy**2 + dz**2`` between sample ``t`` and sample ``t + dt`` is
    averaged over every start index ``t`` for which ``t + dt`` exists.

    Parameters
    ----------
    position_dict : dict
        Mapping with the keys ``'x_vals'``, ``'y_vals'`` and ``'z_vals'``, each
        a sequence of coordinates. The number of samples is taken from
        ``'x_vals'``; the other two are indexed alongside it.

    Returns
    -------
    list
        A one-element list ``[msd]``. ``msd`` is a 1-D float array with one
        entry per sample: ``msd[0]`` is 0 and ``msd[dt]`` is the average
        described above. An empty track yields an empty array.

    Raises
    ------
    IndexError
        If there are at least two samples and ``'y_vals'`` or ``'z_vals'`` is
        shorter than ``'x_vals'``.
    """
    x_vals = np.array(position_dict['x_vals'])
    y_vals = np.array(position_dict['y_vals'])
    z_vals = np.array(position_dict['z_vals'])

    # Number of time points (and therefore of lags, counting lag 0).
    num_points = len(x_vals)

    # Without this check the slices below could silently broadcast a too-short
    # y/z series instead of failing.
    if num_points > 1 and min(len(y_vals), len(z_vals)) < num_points:
        raise IndexError("'y_vals' and 'z_vals' must be at least as long as 'x_vals'")

    # msd[0] stays 0: a track has no displacement from itself.
    msd = np.zeros(num_points)

    for dt in range(1, num_points):
        # Element k of each difference is (sample k + dt) - (sample k), i.e. all
        # start indices for this lag are handled in one array operation.
        dx = x_vals[dt:num_points] - x_vals[:num_points - dt]
        dy = y_vals[dt:num_points] - y_vals[:num_points - dt]
        dz = z_vals[dt:num_points] - z_vals[:num_points - dt]
        msd[dt] = np.mean(dx**2 + dy**2 + dz**2)

    # Wrapped in a list because that is the shape existing callers receive.
    return [msd]

In [52]:
# Determine which z partition every selected atom occupies in the first frame.
# atom_initial_part_list[n] is the index of the first boundary in z_partitions
# that lies strictly above atom n's z coordinate.
global_uni.trajectory[0]  # make Oxg_uni.positions refer to frame 0
Oxg_positions = Oxg_uni.positions

# Index recorded for an atom whose z is not below any boundary. Recording a
# value for such atoms (instead of skipping them) keeps the list aligned with
# the atom indices, so entry n always belongs to atom n.
out_of_range_part = len(z_partitions)

atom_initial_part_list = [
    next(
        (part for part, bound in enumerate(z_partitions) if current_z < bound),
        out_of_range_part,
    )
    for current_z in Oxg_positions[:, 2]
]

In [58]:
def _partition_index(z_value, boundaries):
    """Return the index of the first boundary strictly above z_value.

    Returns len(boundaries) when z_value is not below any boundary, so the
    result is always defined.
    """
    for index, bound in enumerate(boundaries):
        if z_value < bound:
            return index
    return len(boundaries)


# Collect, for every atom, its positions over the leading run of frames during
# which it stays in the partition recorded in atom_initial_part_list. Recording
# for an atom stops for good at the first frame where its partition differs.
atom_count = len(Oxg_uni)
atom_dict = {
    atom_num: {'x_vals': [], 'y_vals': [], 'z_vals': []}
    for atom_num in range(atom_count)
}

# Atoms that have not yet left their initial partition.
active_atoms = list(range(atom_count))

# Each frame is loaded once and shared by all atoms, rather than walking the
# whole trajectory again for every atom.
for _ in global_uni.trajectory:
    if not active_atoms:
        break  # every atom has left its partition; later frames add nothing

    Oxg_positions = Oxg_uni.positions
    still_active = []
    for atom_num in active_atoms:
        x_val, y_val, current_z = Oxg_positions[atom_num]
        if _partition_index(current_z, z_partitions) == atom_initial_part_list[atom_num]:
            track = atom_dict[atom_num]
            track['x_vals'].append(x_val)
            track['y_vals'].append(y_val)
            track['z_vals'].append(current_z)
            still_active.append(atom_num)
    active_atoms = still_active

# MSD of every atom's (possibly truncated) track, keyed by atom index.
msd_dict = {
    atom_num: calculate_msd(track) for atom_num, track in atom_dict.items()
}

In [69]:
# Print every atom index followed by its MSD result, one entry per atom.
for atom_index, msd_result in msd_dict.items():
    print(f'{atom_index}: {msd_result}')

0: [array([0.        , 8.67543823])]
1: [array([ 0.        , 10.43933168])]
2: [array([ 0.        , 18.69044643, 22.30996668, 16.81509123, 30.08075721,
       33.2751235 , 26.53958395, 30.29516961, 30.54043359, 38.63094153,
       44.93698596])]
3: [array([  0.        ,  34.94576473,  47.16170447,  56.69380956,
        69.27850286,  84.40881544, 106.82011209, 113.48640279,
       122.90223722, 138.00338876, 143.75036394, 131.1092411 ,
       131.95646698, 115.42700255, 134.56633024, 144.60158159,
       140.02289916, 142.00607864, 149.41520567, 233.199161  ,
       269.11886228, 273.60254624, 415.74021298])]
4: [array([  0.        ,  28.07912346,  67.35866135, 106.99245828,
       151.96217255, 205.67614083, 267.38596075, 328.77395547,
       375.02383396, 417.86701856, 463.56352363, 500.4625936 ,
       517.11798405, 535.58560601, 531.60775374, 499.59056062,
       428.13796282, 397.60722098, 346.65774957, 323.61071091,
       354.02685652, 304.93059335, 411.04227636])]
5: [array([ 0.