import numbers
import unittest

from numpy import array, asarray, exp, loadtxt
import scipy.optimize as opt
class SasData(object):
    """A data object for holding 1-d Q versus I SAS data.

    The root class is deliberately simple: it stores matching lists of
    Q values (``q``) and intensities (``i``) and supports element-wise
    addition and subtraction, multiplication by a number, ``len()`` and
    a readable ``repr()``. Every arithmetic operation returns a new
    SasData object; the operands are never modified.
    """

    def __init__(self, q_vals, i_vals):
        """Store the Q and I lists.

        Both arguments must be lists of the same length. They are stored
        as given (not copied). Possibly an argument for including the
        units of Q in the root object.

        Raises AssertionError if either argument is not a list or if the
        lengths differ.
        """
        assert isinstance(q_vals, list)
        assert isinstance(i_vals, list)
        assert len(q_vals) == len(i_vals), 'q and i not the same length'
        self.q = q_vals
        self.i = i_vals

    def __len__(self):
        """Return the number of (Q, I) points."""
        return len(self.q)

    def __repr__(self):
        """Return e.g. ``SasData(q=[0, 1], i=[5, 6])``."""
        return '%s(q=%r, i=%r)' % (type(self).__name__, self.q, self.i)

    @staticmethod
    def _is_scalar(value):
        """Return True if ``value`` is a real number other than a bool."""
        # bool subclasses int; the original exact-type checks rejected it,
        # while numpy scalars and Python 2 longs are now accepted.
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    def __add__(self, other):
        """Add another SasData object or a real number to this one.

        Two cases are handled:

        * ``other`` is a SasData object with identical Q values: the
          intensities are summed point by point.
        * ``other`` is a real number (int, float, numpy scalar, ...; bool
          excluded): it is added to every intensity. Use a negative
          number, or the ``-`` operator, to subtract a constant.

        Returns a new SasData object whose ``q`` is a copy of this
        object's Q list. Raises AssertionError for any other operand
        type, or when two SasData objects differ in length or Q values.
        """
        if isinstance(other, SasData):
            assert len(self) == len(other), 'datasets not the same length'
            assert self.q == other.q, 'q values not the same'
            summed = [a + b for a, b in zip(self.i, other.i)]
        else:
            assert self._is_scalar(other)
            summed = [a + other for a in self.i]
        # Copy q so the result never shares a mutable list with self.
        return SasData(list(self.q), summed)

    # Addition is commutative, so ``4 + data`` behaves like ``data + 4``.
    __radd__ = __add__

    def __sub__(self, other):
        """Subtract another SasData object or a real number.

        Implemented through ``__add__`` (adding ``-1 * other``), so the
        same validation and AssertionError behaviour apply.
        """
        if isinstance(other, SasData):
            return self + other * -1
        assert self._is_scalar(other)
        return self + (-other)

    def __mul__(self, other):
        """Multiply every intensity by a real number (bool excluded).

        Returns a new SasData object with a copied Q list. Raises
        AssertionError if ``other`` is not a real number.
        """
        assert self._is_scalar(other)
        return SasData(list(self.q), [a * other for a in self.i])

    # Scalar multiplication is commutative: ``2 * data == data * 2``.
    __rmul__ = __mul__
def guinier(q, param):
    """Return Guinier-model intensities for the given Q values.

    Evaluates ``I(q) = I0 * exp(-(Rg**2) * (q**2) / 3) + background``,
    where ``param`` holds (I0, Rg, background) in that order.

    ``q`` may be a plain list of Q values or any array-like. It is
    converted with ``numpy.asarray`` so the arithmetic is element-wise.
    Returns a numpy array with one intensity per Q value.

    Raises AssertionError if ``param`` does not have exactly three
    entries or if ``q`` is empty.
    """
    assert len(param) == 3, 'Guinier requires three parameters'
    assert len(q) != 0, 'Zero length q vector'
    i_zero, r_g, background = param
    # A plain list does not support ** element-wise; convert to an array.
    q = asarray(q, dtype=float)
    # Divide by the float 3.0: under Python 2, (-1/3) is integer division
    # and evaluates to -1, which silently gave the wrong exponent.
    return i_zero * exp(-(r_g ** 2) * (q ** 2) / 3.0) + background
def guinier_residuals(param, q, i):
    """Return the residuals ``i - guinier(q, param)`` as a numpy array.

    The argument order follows the ``func(params, *args)`` convention
    used by ``scipy.optimize.leastsq``.

    :param param: any length-3 sequence (I0, Rg, background). leastsq
        passes a numpy array here, so no list type is required; the
        length is validated by ``guinier``.
    :param q: non-empty list of Q values.
    :param i: list of measured intensities, same length as ``q``.
    :raises AssertionError: on invalid ``q``/``i`` or a bad parameter count.
    """
    assert type(q) is list, 'Need Q values to calculate residuals'
    assert type(i) is list, 'Need I values to calculate residuals'
    assert len(i) == len(q), 'Q and I not the same length?'
    assert len(q) > 0, 'Data seems to have zero length'
    return asarray(i, dtype=float) - guinier(q, param)
def fit_model(data, model, param_0=None):
    """Least-squares fit of a model to a SasData object.

    Only the Guinier model is implemented at the moment. ``model`` is
    accepted for future model selection but is currently ignored.

    :param data: SasData object whose ``q`` and ``i`` lists are fitted.
    :param model: placeholder for the model to fit (currently unused).
    :param param_0: optional initial guess (I0, Rg, background);
        defaults to (1.0, 1.0, 0.0).
    :returns: the value returned by ``scipy.optimize.leastsq`` with its
        default arguments, i.e. the ``(best_params, ier)`` pair. See the
        leastsq documentation for the meaning of ``ier``.
    :raises AssertionError: if ``data`` is not a SasData object.
    """
    assert isinstance(data, SasData)
    if param_0 is None:
        # some reasonable initial values for Guinier
        param_0 = [1.0, 1.0, 0.0]
    # Previously this result was computed and then dropped, so the
    # function always returned None.
    return opt.leastsq(guinier_residuals, param_0, args=(data.q, data.i))
class TestSasData(unittest.TestCase):
    """Unit tests for the basic SasData container operations."""

    def setUp(self):
        # Wrap range() in list(): SasData only accepts lists, and on
        # Python 3 range() returns a range object rather than a list.
        self.zero_to_nine = list(range(10))
        self.test_string = 'a string'
        self.nine_to_zero = list(range(9, -1, -1))
        self.test_zero = [0]
        self.zero_to_twenty = list(range(20))
        self.eighteen_to_zero = list(range(18, -2, -2))
        self.thirteen_to_four = list(range(13, 3, -1))
        # nine_to_zero scaled by 0.25; each value is exact in binary
        self.test_floats = [9. * 0.25, 2., 7. * 0.25, 1.5, 1.25, 1.,
                            3. * 0.25, 0.5, 0.25, 0.]
        self.test_data_ranges = SasData(self.zero_to_nine, self.nine_to_zero)
        self.test_data_zeros = SasData(self.test_zero, self.test_zero)

    def test_init(self):
        self.assertEqual(self.test_data_ranges.q, self.zero_to_nine)
        self.assertEqual(self.test_data_ranges.i, self.nine_to_zero)
        # a non-list argument is rejected
        self.assertRaises(
            AssertionError, SasData, self.test_string, self.test_zero)
        # q and i of different lengths are rejected
        self.assertRaises(
            AssertionError, SasData, self.test_zero, self.zero_to_twenty)

    def test_len(self):
        self.assertEqual(len(self.test_data_ranges), 10)
        self.assertEqual(len(self.test_data_zeros), 1)
        self.assertEqual(len(SasData([], [])), 0)

    def test_add(self):
        # SasData + SasData sums intensities point by point
        test_add = self.test_data_ranges + self.test_data_ranges
        self.assertEqual(self.eighteen_to_zero, test_add.i)
        self.assertEqual(self.zero_to_nine, test_add.q)
        # SasData + number offsets every intensity
        test_add = self.test_data_ranges + 4
        self.assertEqual(self.thirteen_to_four, test_add.i)
        self.assertEqual(self.zero_to_nine, test_add.q)
        # any other operand type is rejected
        self.assertRaises(AssertionError,
                          self.test_data_ranges.__add__, self.test_string)

    def test_mul(self):
        # test simple multiplication
        test_mul = self.test_data_ranges * 2
        self.assertEqual(self.eighteen_to_zero, test_mul.i)
        self.assertEqual(self.zero_to_nine, test_mul.q)
        # test multiplication by zero
        test_mul = self.test_data_ranges * 0
        self.assertEqual([0] * len(self.test_data_ranges), test_mul.i)
        self.assertEqual(self.zero_to_nine, test_mul.q)
        # test multiplication by a float
        test_mul = self.test_data_ranges * 0.25
        self.assertEqual(self.test_floats, test_mul.i)
        self.assertEqual(self.zero_to_nine, test_mul.q)
if '__main__' == __name__:
    # Executed directly as a script (not imported): run the unittest
    # test cases defined in this module.
    unittest.main()